Enzyme (Both Forward and Reverse) fails to differentiate the total energy of a simulation using the linear scalar advection equation wrt k #1

Open
junyixu opened this issue Sep 29, 2024 · 0 comments
Assignees
Labels
bug Something isn't working help wanted Extra attention is needed question Further information is requested

Comments

Owner

junyixu commented Sep 29, 2024

Difficulty with differentiating total energy w.r.t. wave number using Enzyme.jl and Trixi.jl:

Description:

Enzyme works in simple cases, such as differentiating the sine function:

julia> autodiff(Forward, sin, Duplicated(0.0, 1.0))
(1.0,)

julia> autodiff(Reverse, sin, Active, Active(0.0))
((1.0,),)

However, when I extend this approach to a more complex scenario, using Trixi.jl to differentiate the total energy with respect to the wave number k, Enzyme fails to compile. Here is the minimal working example (MWE):

MWE:

using Trixi, OrdinaryDiffEq
import Enzyme
using Enzyme: autodiff, Forward, Duplicated, Const
using Enzyme: Reverse, Active

Enzyme.API.runtimeActivity!(true)

function energy_at_final_time(k) # k is the wave number of the initial condition
    equations = LinearScalarAdvectionEquation2D(1.0, -0.3)
    mesh = TreeMesh((-1.0, -1.0), (1.0, 1.0), initial_refinement_level=3, n_cells_max=10^4)
    solver = DGSEM(3, flux_lax_friedrichs)
    initial_condition = (x, t, equation) -> begin
        x_trans = Trixi.x_trans_periodic_2d(x - equation.advection_velocity * t)
        return SVector(sinpi(k * sum(x_trans)))
    end
    semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver)
    ode = semidiscretize(semi, (0.0, 1.0))
    sol = solve(ode, BS3(), save_everystep=false)
    Trixi.integrate(energy_total, sol.u[end], semi)
end

# autodiff(Forward, energy_at_final_time, Duplicated(1.0, 1.0))
autodiff(Reverse, energy_at_final_time, Active, Active(1.0))

ForwardDiff Considerations:

ForwardDiff.jl uses dual numbers, which store both a value and its derivative with respect to a specified parameter, so every buffer in the computation must be able to hold ForwardDiff.Dual values. In this case, that is achieved by passing uEltype=typeof(k) to the SemidiscretizationHyperbolic call:

semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver, uEltype=typeof(k))
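To illustrate why the element type matters, here is a minimal sketch that does not use Trixi.jl. The function toy_energy and its uEltype keyword are hypothetical stand-ins for a solver that preallocates its solution buffer; only ForwardDiff.derivative is an actual library call.

```julia
using ForwardDiff

# Hypothetical stand-in for a solver: fill a preallocated buffer that
# depends on k, then reduce it to a scalar "energy".
function toy_energy(k; uEltype = Float64)
    u = zeros(uEltype, 4)          # buffer whose element type is fixed up front
    for i in eachindex(u)
        u[i] = sin(k * i)          # stores a Dual when k is a Dual
    end
    return sum(abs2, u)
end

# Works: the buffer element type follows the type of k.
dk = ForwardDiff.derivative(k -> toy_energy(k; uEltype = typeof(k)), 1.0)

# Fails: with the default Float64 buffer a Dual cannot be stored in u.
float_buffer_failed = try
    ForwardDiff.derivative(toy_energy, 1.0)
    false
catch
    true
end
```

With the uEltype change applied to the MWE, the analogous call would presumably be ForwardDiff.derivative(energy_at_final_time, 1.0).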

Enzyme Considerations:

Enzyme operates on LLVM's intermediate representation (IR) rather than on a special number type, so different modifications might be needed. For instance, we might need to expose certain parts of the cache for differentiation to succeed.
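One modification that might be worth trying, given that the stacktrace in the Error section points at the TreeMesh constructor, is to build everything that does not depend on k outside the differentiated function and pass it in as a Const argument. The following is only a sketch of that pattern on a toy problem: the Grid type and the energy function are hypothetical, and whether the same restructuring resolves the Trixi.jl failure is untested.

```julia
using Enzyme: autodiff, Reverse, Active, Const

# Hypothetical stand-in for setup data (mesh, solver, cache) that does not
# depend on the parameter being differentiated.
struct Grid
    nodes::Vector{Float64}
end

# Only the k-dependent computation lives inside the differentiated function.
function energy(k, grid::Grid)
    acc = zero(k)
    for x in grid.nodes
        acc += sin(pi * k * x)^2
    end
    return acc / length(grid.nodes)
end

# The setup is constructed once, outside of autodiff.
grid = Grid(collect(range(-1.0, 1.0; length = 9)))

# Reverse mode returns a tuple holding one derivative slot per argument;
# the slot for the Const argument is `nothing`.
grads = autodiff(Reverse, energy, Active, Active(1.0), Const(grid))
dk = grads[1][1]
```

In the MWE, the analogous step would be constructing mesh and solver before calling autodiff, so that Enzyme never has to compile the TreeMesh constructor.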

Additionally, I received the following warning from Enzyme.jl:

Warning: Using fallback BLAS replacements for (["dgemm_64_"]), performance may be degraded

The warning says that Enzyme is falling back to its own replacement for the BLAS routine dgemm_64_, which may degrade performance. BLAS could therefore be involved, although according to Enzyme's documentation this is unlikely to be the main problem.

Error:

I encountered the following error when trying to differentiate energy_at_final_time using Enzyme's reverse-mode autodiff:

ERROR: Enzyme compilation failed.
Current scope:
define internal fastcc nonnull {} addrspace(10)* @julia__TreeMesh_204_15137() unnamed_addr #54 !dbg !2616 {
top:
  %0 = alloca i64, align 16
  %1 = bitcast i64* %0 to i8*
  %2 = alloca i64, align 16
  %3 = bitcast i64* %2 to i8*
  %4 = call {}*** @julia.get_pgcstack()
  %current_task147 = getelementptr inbounds {}**, {}*** %4, i64 -14
  %current_task1 = bitcast {}*** %current_task147 to {}**
  %ptls_field48 = getelementptr inbounds {}**, {}*** %4, i64 2
  %5 = bitcast {}*** %ptls_field48 to i64***
  %ptls_load4950 = load i64**, i64*** %5, align 8, !tbaa !71
  %6 = getelementptr inbounds i64*, i64** %ptls_load4950, i64 2
  %safepoint = load i64*, i64** %6, align 8, !tbaa !75, !invariant.load !70
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !2617
  fence syncscope("singlethread") seq_cst
  br i1 true, label %L8, label %L4, !dbg !2618

L4:                                               ; preds = %top
  unreachable

L8:                                               ; preds = %top
  br i1 true, label %L15, label %L11, !dbg !2619

L11:                                              ; preds = %L8
  unreachable

L15:                                              ; preds = %L8
  %7 = load i8, i8* inttoptr (i64 130278523260736 to i8*), align 64, !dbg !2620, !tbaa !170, !alias.scope !174, !noalias !728
  %8 = and i8 %7, 1, !dbg !2620
  %.not = icmp eq i8 %8, 0, !dbg !2620
  %spec.select = select i1 %.not, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130278519334512 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130278519333632 to {}*) to {} addrspace(10)*), !dbg !2627
  %9 = load i8, i8* inttoptr (i64 130278937627225 to i8*), align 1, !dbg !2628, !tbaa !170, !alias.scope !174, !noalias !728
  %10 = and i8 %9, 1, !dbg !2628
  %.not51 = icmp eq i8 %10, 0, !dbg !2628
  br i1 %.not51, label %L53, label %L52, !dbg !2633

L52:                                              ; preds = %L15
  %11 = call fastcc nonnull {} addrspace(10)* @julia_push__15142(), !dbg !2634
  %phi.cast = addrspacecast {} addrspace(10)* %11 to {} addrspace(11)*, !dbg !2633
  br label %L53, !dbg !2633

L53:                                              ; preds = %L52, %L15
  %nodecayed.value_phi15 = phi {} addrspace(10)*
  %nodecayedoff.value_phi15 = phi i64
  %value_phi15 = phi {} addrspace(11)* [ %phi.cast, %L52 ], [ addrspacecast ({} addrspace(10)* null to {} addrspace(11)*), %L15 ]
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %1)
  store i64 0, i64* %0, align 16, !dbg !2635, !tbaa !170, !alias.scope !174, !noalias !2642
  %bitcast_coercion19 = ptrtoint i64* %0 to i64, !dbg !2645
  call void @ijl_gc_get_total_bytes(i64 noundef %bitcast_coercion19) [ "jl_roots"({} addrspace(10)* null) ], !dbg !2651
  %12 = load i64, i64* %0, align 16, !dbg !2654, !tbaa !170, !alias.scope !174, !noalias !728
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %1)
  %13 = call i64 @ijl_hrtime(), !dbg !2657
  %14 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @jl_f_apply_type, {} addrspace(10)* noundef null, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130278519236736 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130279731774080 to {}*) to {} addrspace(10)*), {} addrspace(10)* %spec.select) #33, !dbg !2660
  %15 = call noalias nonnull {} addrspace(10)* @ijl_box_int64(i64 noundef signext 10000) #31, !dbg !2660
  %box30 = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 16, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 130279504534320 to {}*) to {} addrspace(10)*)) #65, !dbg !2660
  %16 = bitcast {} addrspace(10)* %box30 to i8 addrspace(10)*, !dbg !2660
  %newstruct.sroa.0.0..sroa_cast = bitcast {} addrspace(10)* %box30 to double addrspace(10)*, !dbg !2660
  store double 0.000000e+00, double addrspace(10)* %newstruct.sroa.0.0..sroa_cast, align 8, !dbg !2660, !tbaa !131, !alias.scope !1147, !noalias !2661
  %newstruct.sroa.2.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %16, i64 8, !dbg !2660
  %newstruct.sroa.2.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct.sroa.2.0..sroa_idx to double addrspace(10)*, !dbg !2660
  store double 0.000000e+00, double addrspace(10)* %newstruct.sroa.2.0..sroa_cast, align 8, !dbg !2660, !tbaa !131, !alias.scope !1147, !noalias !2661
  %box32 = call noalias nonnull dereferenceable(8) {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 130279575136160 to {}*) to {} addrspace(10)*)) #65, !dbg !2660
  %17 = bitcast {} addrspace(10)* %box32 to double addrspace(10)*, !dbg !2660
  store double 2.000000e+00, double addrspace(10)* %17, align 8, !dbg !2660, !tbaa !180, !alias.scope !174, !noalias !2642
  %18 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @ijl_apply_generic, {} addrspace(10)* nonnull %14, {} addrspace(10)* nonnull %15, {} addrspace(10)* nofree nonnull %box30, {} addrspace(10)* nofree nonnull %box32, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130279642848480 to {}*) to {} addrspace(10)*)) #32, !dbg !2660
  br i1 %.not51, label %L137, label %ok, !dbg !2662

L91:                                              ; preds = %ok
  call fastcc void @julia_throw_inexacterror_15006(i64 zeroext %37) #66, !dbg !2663
  unreachable, !dbg !2663

L98:                                              ; preds = %ok
  store i64 %37, i64 addrspace(11)* %33, align 8, !dbg !2676, !tbaa !170, !alias.scope !174, !noalias !2642
  %19 = getelementptr inbounds i8, i8 addrspace(11)* %31, i64 16, !dbg !2677
  %20 = bitcast i8 addrspace(11)* %19 to i64 addrspace(11)*, !dbg !2677
  %21 = load i64, i64 addrspace(11)* %20, align 8, !dbg !2677, !tbaa !170, !alias.scope !174, !noalias !728
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %3)
  store i64 0, i64* %2, align 16, !dbg !2679, !tbaa !170, !alias.scope !174, !noalias !2642
  %bitcast_coercion37 = ptrtoint i64* %2 to i64, !dbg !2682
  call void @ijl_gc_get_total_bytes(i64 noundef %bitcast_coercion37) [ "jl_roots"({} addrspace(10)* null) ], !dbg !2685
  %22 = load i64, i64* %2, align 16, !dbg !2687, !tbaa !170, !alias.scope !174, !noalias !728
  %23 = sub i64 %21, %12, !dbg !2690
  %24 = add i64 %23, %22, !dbg !2692
  store i64 %24, i64 addrspace(11)* %20, align 8, !dbg !2694, !tbaa !170, !alias.scope !174, !noalias !2642
  %25 = bitcast {} addrspace(11)* %value_phi15 to i64 addrspace(11)*, !dbg !2695
  %26 = load i64, i64 addrspace(11)* %25, align 8, !dbg !2695, !tbaa !170, !alias.scope !174, !noalias !728
  %27 = add i64 %26, 1, !dbg !2697
  store i64 %27, i64 addrspace(11)* %25, align 8, !dbg !2698, !tbaa !170, !alias.scope !174, !noalias !2642
  %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* bitcast (i8 addrspace(11)* getelementptr inbounds (i8, i8 addrspace(11)* addrspacecast (i8* inttoptr (i64 130278937627184 to i8*) to i8 addrspace(11)*), i64 24) to {} addrspace(10)* addrspace(11)*) unordered, align 8, !dbg !2699, !tbaa !170, !alias.scope !174, !noalias !728, !nonnull !70, !dereferenceable !1222, !align !1223
  %28 = addrspacecast {} addrspace(10)* %getfield to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !2703
  %arraylen_ptr = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %28, i64 0, i32 1, !dbg !2703
  %arraylen = load i64, i64 addrspace(11)* %arraylen_ptr, align 8, !dbg !2703, !tbaa !158, !range !82, !alias.scope !161, !noalias !162
  %.not53 = icmp eq i64 %arraylen, 0, !dbg !2709
  br i1 %.not53, label %L127, label %idxend, !dbg !2707

L127:                                             ; preds = %L98
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %3)
  call fastcc void @julia__throw_argerror_14962({} addrspace(10)* nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 130279626396464 to {}*) to {} addrspace(10)*)) #66, !dbg !2711
  unreachable, !dbg !2711

L137:                                             ; preds = %pass, %L53
  %29 = call noalias nonnull {} addrspace(10)* @ijl_box_int64(i64 noundef signext 3) #31, !dbg !2712
  %30 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* noundef nonnull @ijl_apply_generic, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 130278519429968 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %18, {} addrspace(10)* nonnull %29, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130279573993792 to {}*) to {} addrspace(10)*), {} addrspace(10)* addrspacecast ({}* inttoptr (i64 130279573993792 to {}*) to {} addrspace(10)*)) #32, !dbg !2712
  ret {} addrspace(10)* %18, !dbg !2713

ok:                                               ; preds = %L53
  %31 = bitcast {} addrspace(11)* %value_phi15 to i8 addrspace(11)*, !dbg !2714
  %32 = getelementptr inbounds i8, i8 addrspace(11)* %31, i64 8, !dbg !2714
  %33 = bitcast i8 addrspace(11)* %32 to i64 addrspace(11)*, !dbg !2714
  %34 = load i64, i64 addrspace(11)* %33, align 8, !dbg !2714, !tbaa !170, !alias.scope !174, !noalias !728
  %35 = call i64 @ijl_hrtime(), !dbg !2715
  %36 = sub i64 %35, %13, !dbg !2716
  %37 = add i64 %36, %34, !dbg !2717
  %38 = icmp sgt i64 %37, -1, !dbg !2719
  br i1 %38, label %L98, label %L91, !dbg !2663

idxend:                                           ; preds = %L98
  %39 = add nsw i64 %arraylen, -1, !dbg !2721
  %40 = addrspacecast {} addrspace(10)* %getfield to {} addrspace(10)* addrspace(13)* addrspace(11)*, !dbg !2721
  %arrayptr54 = load {} addrspace(10)* addrspace(13)*, {} addrspace(10)* addrspace(13)* addrspace(11)* %40, align 16, !dbg !2721, !tbaa !163, !alias.scope !2724, !noalias !162, !nonnull !70
  %41 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %arrayptr54, i64 %39, !dbg !2721
  %arrayref = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %41, align 8, !dbg !2721, !tbaa !1429, !alias.scope !174, !noalias !728
  %.not55 = icmp eq {} addrspace(10)* %arrayref, null, !dbg !2721
  br i1 %.not55, label %fail, label %pass, !dbg !2721

fail:                                             ; preds = %idxend
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %3)
  call void @ijl_throw({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 130279573993952 to {}*) to {} addrspace(12)*)) #66, !dbg !2721
  unreachable, !dbg !2721

pass:                                             ; preds = %idxend
  call void @llvm.lifetime.end.p0i8(i64 noundef 8, i8* noundef nonnull %3)
  call void @ijl_array_del_end({} addrspace(10)* noundef nonnull %getfield, i64 noundef 1), !dbg !2725
  br label %L137, !dbg !2728
}

Could not analyze garbage collection behavior of
 inst:   %value_phi15 = phi {} addrspace(11)* [ %phi.cast, %L52 ], [ addrspacecast ({} addrspace(10)* null to {} addrspace(11)*), %L15 ]
 v0: {} addrspace(11)* addrspacecast ({} addrspace(10)* null to {} addrspace(11)*)
 v: {} addrspace(11)* addrspacecast ({} addrspace(10)* null to {} addrspace(11)*)
 offset: i64 0
 hasload: false


Stacktrace:
 [1] TreeMesh
   @ ~/WorkSpace/jacobian4DG/Differentiating_through_a_complete_simulation/dev/Trixi/src/meshes/tree_mesh.jl:107
 [2] energy_at_final_time
   @ ./REPL[14]:4

Stacktrace:
  [1] (::Enzyme.Compiler.var"#getparent#382"{…})(v::LLVM.ConstantExpr, offset::LLVM.ConstantInt, hasload::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler/optimize.jl:467
  [2] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler/optimize.jl:470
  [3] optimize!
    @ ~/.julia/packages/Enzyme/l4FS0/src/compiler/optimize.jl:1531 [inlined]
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:4781
  [5] codegen
    @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:4340 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5430
  [7] _thunk
    @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5430 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5464 [inlined]
  [9] (::Enzyme.Compiler.var"#532#533"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5530
 [10] JuliaContext(f::Enzyme.Compiler.var"#532#533"{…})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [11] #s1883#531
    @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5482 [inlined]
 [12]
    @ Enzyme.Compiler ./none:0
 [13] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [14] autodiff(::EnzymeCore.ReverseMode{…}, f::Const{…}, ::Type{…}, args::Active{…})
    @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:197
 [15] autodiff(::EnzymeCore.ReverseMode{false, EnzymeCore.FFIABI}, ::typeof(energy_at_final_time), ::Type, ::Active{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
 [16] top-level scope
    @ REPL[15]:1
 [17] top-level scope
    @ none:1
Some type information was truncated. Use `show(err)` to see complete types.

Request:

  • Are there specific steps needed to handle differentiation with Enzyme, particularly around cache exposure?
@junyixu junyixu added bug Something isn't working help wanted Extra attention is needed question Further information is requested labels Sep 29, 2024
@junyixu junyixu self-assigned this Sep 29, 2024