Skip to content

Commit

Permalink
chore:
Browse files Browse the repository at this point in the history
1. test common
2. docs
  • Loading branch information
junyixu committed Aug 24, 2024
1 parent 5536e43 commit bb56efa
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 6 deletions.
3 changes: 0 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
using Documenter
using TrixiEnzyme
import Trixi
import Enzyme
import Polyester

DocMeta.setdocmeta!(TrixiEnzyme, :DocTestSetup, :(using TrixiEnzyme); recursive=true)

Expand Down
3 changes: 2 additions & 1 deletion src/TrixiEnzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The integration of Trixi.jl with Compiler-Based (LLVM level) automatic different
module TrixiEnzyme

export plusTwo, jacobian_enzyme_forward, jacobian_enzyme_forward_closure
export autodiff, Forward, Reverse, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed

using Trixi: AbstractEquations, TreeMesh, DGSEM,
BoundaryConditionPeriodic, SemidiscretizationHyperbolic,
Expand All @@ -15,7 +16,7 @@ using Trixi: AbstractEquations, TreeMesh, DGSEM,
boundary_condition_periodic,
set_log_type, set_sqrt_type
import Enzyme
using Enzyme: autodiff, Forward, Reverse, Duplicated, DuplicatedNoNeed, make_zero
using Enzyme: autodiff, Forward, Reverse, Duplicated, DuplicatedNoNeed, make_zero, BatchDuplicated, BatchDuplicatedNoNeed
using Polyester: @batch

"""
Expand Down
66 changes: 66 additions & 0 deletions test/CommonTest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
module CommonTest

using Test
using TrixiEnzyme
using TrixiEnzyme:Enzyme
using ForwardDiff

# %%

function foo!(y, x, cache)
cache .= x.^2
y .= 2cache
return nothing
end

function foo!(y, x)
y .= 2*(x.^2)
return nothing
end
# %%


x = ones(3)
y = similar(x)
cache = similar(x)

dx = Enzyme.onehot(x)
dy = ntuple(_->zeros(size(x)), length(x))
dcache = ntuple(_->zeros(size(x)), length(x))

dcache1 = similar(x)
dx1 = zero(x)
dx1[1] = 1.0
dy1 = similar(x)

autodiff(Forward, foo!, BatchDuplicated(y, dy), BatchDuplicated(x, dx))
J1 = stack(dy)

autodiff(Forward, foo!, BatchDuplicated(y, dy), BatchDuplicated(x, dx), BatchDuplicatedNoNeed(cache, dcache))
J2 = stack(dy)

autodiff(Forward, foo!, BatchDuplicated(y, dy), BatchDuplicated(x, dx), BatchDuplicated(cache, dcache))
J3 = stack(dy)

autodiff(Forward, foo!, Duplicated(y, dy1), Duplicated(x, dx1), DuplicatedNoNeed(cache, dcache1))


# %%
@info "testing foo!(y, x, cache)"


# cfg = ForwardDiff.JacobianConfig(nothing, y, x, ForwardDiff.Chunk(2))
cfg = ForwardDiff.JacobianConfig(nothing, y, x)
uEltype = eltype(cfg)
nan_uEltype=convert(uEltype, NaN)
cache=fill(nan_uEltype, length(x))

J = ForwardDiff.jacobian(y, x, cfg) do y,x
foo!(y, x, cache)
end

@test J == J1 == J2 == J3

@test J[:, 1] == dy1

end
237 changes: 236 additions & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,200 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "71d91126b5a1fb1020e1098d9d492de2a4438fd2"
project_hash = "9ff080768133d7f8c1732f8504c8ba4c64808f15"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.CommonSubexpressions]]
deps = ["MacroTools"]
git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.1"

[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffResults]]
deps = ["StaticArraysCore"]
git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.1.0"

[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.15.1"

[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.9.3"

[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.36"

[deps.ForwardDiff.extensions]
ForwardDiffStaticArraysExt = "StaticArrays"

[deps.ForwardDiff.weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.4"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"

[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.11.0+1"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.28"

[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
LogExpFunctionsInverseFunctionsExt = "InverseFunctions"

[deps.LogExpFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.13"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2023.1.10"

[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"

[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0"

[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.3"

[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -29,6 +207,63 @@ version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.4.0"

[deps.SpecialFunctions.extensions]
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"

[deps.SpecialFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[[deps.StaticArraysCore]]
git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.3"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"

[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"

[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.52.0+1"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+2"
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10 changes: 9 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
using Test, TrixiEnzyme

@test out == 5
@testset "TrixiEnzyme.jl" begin
t0 = time()
@testset "Common" begin
println("##### Testing Common...")
t = @elapsed include("CommonTest.jl")
println("##### done (took $t seconds).")
end
println("##### Running all TrixiEnzyme tests took $(time() - t0) seconds.")
end

0 comments on commit bb56efa

Please sign in to comment.