Skip to content

Commit

Permalink
add enable_tracing scopedvalue
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 13, 2024
1 parent 17c6c7b commit 998ef8a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 7 additions & 4 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

using Base.ScopedValues
const enable_tracing = ScopedValue{Bool}(false)

export @trace, MissingTracedValue

# Traits
Expand Down Expand Up @@ -181,7 +184,7 @@ function trace_for(mod, expr)
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if $(enable_tracing)[] && $(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
$(reactant_code_block)
else
$(expr)
Expand All @@ -195,7 +198,7 @@ function trace_if_with_returns(mod, expr)
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
return quote
if any($(is_traced), ($(all_check_vars...),))
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
$(expr)
Expand Down Expand Up @@ -341,7 +344,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)

return quote
if any($(is_traced), ($(all_check_vars...),))
if $(enable_tracing)[] && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
Expand All @@ -353,7 +356,7 @@ function trace_call(mod, expr)
f = expr.args[1]
args = expr.args[2:end]
return quote
if any($(is_traced), ($(args...), ))
if $(enable_tracing)[]
$(traced_call)($f, $(args...))
else
$(expr)
Expand Down
5 changes: 3 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import ..Reactant:
TracedToConcrete,
append_path,
TracedType
using Base.ScopedValues
import ReactantCore: enable_tracing

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
Expand Down Expand Up @@ -287,15 +289,14 @@ function compile_mlir(f, args; kwargs...)
end
end

using Base.ScopedValues
const callcache = ScopedValue{Dict}()

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
with(callcache=>Dict()) do
with(enable_tracing=>true, callcache=>Dict()) do
return Reactant.make_mlir_fn(f, args, (), "main", true)
end
end
Expand Down

0 comments on commit 998ef8a

Please sign in to comment.