diff --git a/docs/fig/upwind_benchmark.png b/docs/fig/upwind_benchmark.png new file mode 100644 index 0000000..7183488 Binary files /dev/null and b/docs/fig/upwind_benchmark.png differ diff --git a/docs/src/index.md b/docs/src/index.md index 51eba78..529cdd3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -11,6 +11,7 @@ TrixiEnzyme is not a registered Julia package, and it can be installed by runnin ## Configuring Batch Size +To utilize `Enzyme.BatchDuplicated`, one can create a tuple containing duals (or shadows). TrixiEnzyme.jl performs partial derivative evaluation on one "batch" of the input vector at a time. Each differentiation of a batch requires a call to the target function as well as additional memory proportional to the square of the batch's size. Thus, a smaller batch size makes better use of memory bandwidth at the cost of more calls to the target function, @@ -44,5 +45,5 @@ julia> @time jacobian_enzyme_forward(TrixiEnzyme.upwind!, x); ``` Benchmark for a 401x401 Jacobian of `TrixiEnzyme.upwind!` (Lower is better): -![upwind](https://private-user-images.githubusercontent.com/40481594/358694436-21588007-8469-46c5-8b77-e349b27c1c19.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjQ1MTQ2ODMsIm5iZiI6MTcyNDUxNDM4MywicGF0aCI6Ii80MDQ4MTU5NC8zNTg2OTQ0MzYtMjE1ODgwMDctODQ2OS00NmM1LThiNzctZTM0OWIyN2MxYzE5LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA4MjQlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwODI0VDE1NDYyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTVhZjMyZTRkODc1MTcxNzYxNTJlN2M4ZmMxYjUzNWFjZTc1MjBlOTYwOGVjMTAzNzM4YTEyNjA3YzUxMzkzMTImWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.lr66vXaVZMmUnWw24uh30-u754ckRPYzBctskuEntJc) +![upwind benchmark](../fig/upwind_benchmark.png) `Enzyme(@batch)` means applying `Polyester.@batch` to 
`middlebatches`.
diff --git a/src/TrixiEnzyme.jl b/src/TrixiEnzyme.jl
index 422f38f..2df2696 100644
--- a/src/TrixiEnzyme.jl
+++ b/src/TrixiEnzyme.jl
@@ -10,12 +10,12 @@ export plusTwo, jacobian_enzyme_forward, jacobian_enzyme_forward_closure
 export autodiff, Forward, Reverse, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Const
 
 import Trixi
-using Trixi: AbstractEquations, TreeMesh, DGSEM,
+using Trixi: AbstractEquations, TreeMesh, DGSEM, jacobian_ad_forward,
              BoundaryConditionPeriodic, SemidiscretizationHyperbolic,
-             VolumeIntegralWeakForm, VolumeIntegralFluxDifferencing,
+             VolumeIntegralWeakForm, VolumeIntegralFluxDifferencing, wrap_array,
              compute_coefficients, have_nonconservative_terms,
-             boundary_condition_periodic,
-             set_log_type, set_sqrt_type
+             boundary_condition_periodic, LinearScalarAdvectionEquation1D,
+             set_log_type, set_sqrt_type, initial_condition_sine_wave, SVector
 import Enzyme
 using Enzyme: autodiff, Forward, Reverse, Duplicated, DuplicatedNoNeed, make_zero, BatchDuplicated, BatchDuplicatedNoNeed, Const
 using Polyester: @batch
diff --git a/test/SemiTest.jl b/test/SemiTest.jl
index ac5d37a..958039b 100644
--- a/test/SemiTest.jl
+++ b/test/SemiTest.jl
@@ -1,6 +1,5 @@
 module SemiTest
 using Test
-using Trixi
 using TrixiEnzyme
 
 # %%
diff --git a/test/UpwindTest.jl b/test/UpwindTest.jl
new file mode 100644
index 0000000..0b7b210
--- /dev/null
+++ b/test/UpwindTest.jl
@@ -0,0 +1,36 @@
+module UpwindTest
+using Test
+using TrixiEnzyme
+using TrixiEnzyme: upwind!
+using ForwardDiff
+
+# Check that the Enzyme forward-mode Jacobian of `upwind!` agrees with the
+# ForwardDiff Jacobian, and report how long each computation took.
+x = -1:0.005:1
+batch_size = 2
+# Run once with an explicit batch size before the timed call below.
+jacobian_enzyme_forward(TrixiEnzyme.upwind!, x, N=batch_size)
+
+Δt₁ = @elapsed J1 = jacobian_enzyme_forward(TrixiEnzyme.upwind!, x)
+
+Δt₂ = @elapsed begin
+    u = zeros(length(x))
+    du = zeros(length(x))
+    cfg = ForwardDiff.JacobianConfig(nothing, du, u)
+    uEltype = eltype(cfg)
+    nan_uEltype = convert(uEltype, NaN)
+    numerical_flux = fill(nan_uEltype, length(u))
+
+    J2 = ForwardDiff.jacobian(du, u, cfg) do du_ode, u_ode
+        upwind!(du_ode, u_ode, (; v=1.0, numerical_flux))
+    end
+end # 0.326764 seconds (1.09 M allocations: 75.013 MiB, 7.28% gc time, 99.87% compilation time)
+
+# Compare the two results; a bare `=` here would assign instead of testing.
+@test J1 ≈ J2
+
+@info "Compare the time consumed by Enzyme.jl Δt and ForwardDiff.jl..."
+println("Enzyme: ", Δt₁)
+println("ForwardDiff: ", Δt₂)
+
+end