From c0214eb3023195ebc20b0d26479f227f84b00bbd Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 13 Dec 2024 15:06:40 +0100 Subject: [PATCH] caching! --- src/ControlFlow.jl | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 1831a48a..180171d4 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -141,11 +141,9 @@ function ReactantCore.traced_call(f, args...) track_numbers=(), # TODO: track_numbers? ) - if haskey(Reactant.Compiler.callcache[], cache_key) && false - @info "Cache hit" - else - Reactant.Compiler.callcache[][cache_key] = nothing - @warn Reactant.Compiler.callcache[] + if haskey(Reactant.Compiler.callcache[], cache_key) + # Determine `linear_args`, the vector containing `MLIR.IR.Value`s + # to be passed to the function: N = length(args) seen_args = Reactant.OrderedIdDict() traced_args = ntuple(N) do i @@ -163,26 +161,29 @@ function ReactantCore.traced_call(f, args...) v isa TracedType || continue push!(linear_args, v.mlir_data) end - end - - f_name = String(gensym(Symbol(f))) - temp = Reactant.make_mlir_fn( - f, - args, - (), - f_name, - false; - no_args_in_result=true, - ) - - @warn temp - - traced_result, ret, linear_result = temp[[3, 6, 9]] + # cache lookup: + (; f_name, mlir_result_types, traced_result) = Reactant.Compiler.callcache[][cache_key] + else + f_name = String(gensym(Symbol(f))) + temp = Reactant.make_mlir_fn( + f, + args, + (), + f_name, + false; + no_args_in_result=true, + ) + traced_result, ret, linear_args = temp[[3, 6, 7]] + linear_args = MLIR.IR.Value[v.mlir_data for v in linear_args] + mlir_result_types = [MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)] + + Reactant.Compiler.callcache[][cache_key] = (; f_name, mlir_result_types, traced_result) + end call_op = MLIR.Dialects.func.call( linear_args; - result_0=[MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)], + result_0=mlir_result_types, callee=MLIR.IR.FlatSymbolRefAttribute(f_name), ) @@ -195,10 +196,10 @@ function ReactantCore.traced_call(f, args...) toscalar=false, track_numbers=(), ) - linear_results = TracedType[] i = 1 for (k, v) in seen_results v isa TracedType || continue + # this mutates `traced_result`, which is what we want: v.mlir_data = MLIR.IR.result(call_op, i) i += 1 end