Gradient calculation fails when using a parameter-dependent SciMLOperator
#1139
Comments
I think I found a possible partial fix for the out-of-place case. If I consider the out-of-place `ODEProblem` and define

```julia
function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
    return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
end
```

instead of the current implementation

```julia
function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
    L.val = L.update_func(L.val, u, p, t; kwargs...)
    nothing
end

function update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
    update_coefficients!(L, u, p, t; kwargs...)
    L
end
```

then it works:

```julia
Zygote.gradient(my_f, 1.9) # (-0.17161488226273966,)
```

However, it still fails for the Enzyme case and for the in-place version with Zygote (which I guess would be much more efficient). I will make a PR to SciMLOperators.jl to fix at least this case.
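For reference, a minimal sketch of what the out-of-place setup could look like (this is assumed, not taken from the original code; it reuses `U`, `u0`, and `coef` from the in-place MRE in the next comment, and `my_f_oop` is just an illustrative name):

```julia
# Sketch of an out-of-place formulation (assumed, not copied from the original post).
# ODEProblem{false} makes the solver call U(u, p, t) out-of-place, which is the code
# path that the proposed out-of-place update_coefficients method targets.
function my_f_oop(p)
    tspan = (0.0, 1.0)
    prob = ODEProblem{false}(U, u0, tspan, p)
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

Zygote.gradient(my_f_oop, 1.9)
```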
I found that the in-place version works when using the following:

```julia
using LinearAlgebra
using SparseArrays
using OrdinaryDiffEq
using SciMLOperators
using Zygote
using Enzyme
using SciMLSensitivity

##

T = ComplexF64
const N = 10
const u0 = ones(T, N)

# H_tmp = rand(T, N, N)
H_tmp = sprand(T, N, N, 0.5)
const H = H_tmp + H_tmp'

coef(a, u, p, t) = -p[1]

params = T[1]

const U = ScalarOperator(one(params[1]), coef) * MatrixOperator(Diagonal(H)) + MatrixOperator(Diagonal(H))

function my_f(params)
    tspan = (0.0, 1.0)
    # prob = ODEProblem{true}(U, u0, tspan, [γ], sensealg = InterpolatingAdjoint(autojacvec=false))
    prob = ODEProblem{true}(U, u0, tspan, params)
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

my_f(params) # 0.25621142049273665

##

Zygote.gradient(my_f, params)
```

But I get the following warnings during the differentiation:

```
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/ME3jV/src/concrete_solve.jl:67
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/ME3jV/src/concrete_solve.jl:207
(ComplexF64[-1.9743163253371472 + 0.0im],)
```

First, how can I know the [...]?

Then, although I think the problem is related to ReverseDiff.jl, I followed this page and ran

```julia
tspan = (0.0, 1.0)
# prob = ODEProblem{true}(U, u0, tspan, [γ], sensealg = InterpolatingAdjoint(autojacvec=false))
prob = ODEProblem{true}(U, u0, tspan, params)

u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
du = zero(u0)

if DiffEqBase.isinplace(prob)
    _f = prob.f
else
    _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)
tmp3 = zero(u0)
tmp4 = zero(u0)
ytmp = zero(u0)
tmp1 = zero(u0)

Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
    Enzyme.Duplicated(ytmp, tmp1),
    Enzyme.Duplicated(p, tmp2),
    Enzyme.Const(t))
```

which returns nothing for every variable:

```
((nothing, nothing, nothing, nothing),)
```
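To get more detail on why the automated VJP choice failed, the flag mentioned in the warning above could be switched on before re-running the gradient (a sketch based only on the warning text; `my_f` and `params` are the ones defined above):

```julia
# Re-run the gradient with the VJP-warning stack trace enabled
# (flag name taken from the warning message printed above).
SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true
Zygote.gradient(my_f, params)
```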
Now that I better understand how Enzyme.jl works, this is not a bad thing. It means that Enzyme can differentiate the function, right? So I don't understand why I see the warning about [...]. I guess that I can try using [...]. Moreover, can I use [...]?
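For instance, one way to bypass the automated choice is to request a specific VJP backend explicitly through `sensealg`, in the same way as the commented-out line in the MRE above (a sketch; `InterpolatingAdjoint`, `EnzymeVJP`, and `ReverseDiffVJP` come from SciMLSensitivity, and `my_f_enzyme` is just an illustrative name):

```julia
# Force a specific VJP instead of relying on the automated AD choice (sketch).
function my_f_enzyme(params)
    tspan = (0.0, 1.0)
    prob = ODEProblem{true}(U, u0, tspan, params,
                            sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
    sol = solve(prob, Tsit5())
    return real(sol.u[end][end])
end

Zygote.gradient(my_f_enzyme, params)

# Alternatively, compiled ReverseDiff could be requested with
# sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
```

With an explicit choice like this, a failing VJP should surface as an error rather than a silent fallback to numerical VJPs.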
If I try [...], I get

```
ERROR: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…}, ::Const{…}, ::Const{…})
The function `augmented_primal` exists, but no method is defined for this combination of argument types.
Closest candidates are:
augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{typeof(mul!)}, ::Type{RT}, ::Annotation{<:StridedVecOrMat}, ::Const{<:Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int64}}, I}} where I<:(AbstractUnitRange{<:Integer})} where {Tv, Ti}}, ::Annotation{<:StridedVecOrMat}, ::Annotation{<:Number}, ::Annotation{<:Number}) where RT
@ Enzyme ~/.julia/packages/Enzyme/azJki/src/internal_rules.jl:732
augmented_primal(::Any, ::Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::Annotation{T}...; kws...) where {RT, T}
@ QuadGKEnzymeExt ~/.julia/packages/QuadGK/BjmU0/ext/QuadGKEnzymeExt.jl:6
augmented_primal(::Any, ::Const{typeof(NNlib._dropout!)}, ::Type{RT}, ::Any, ::OutType, ::Any, ::Any, ::Any) where {OutType, RT}
@ NNlibEnzymeCoreExt ~/.julia/packages/NNlib/CkJqS/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl:318
```
It looks like you're hitting a missing rule in Enzyme for sparse matrices.
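If it helps to isolate it, the failing call from the signature above can presumably be reproduced outside of the ODE solve with something along these lines (a sketch I am assuming, not code from this thread; whether it triggers exactly the same MethodError is not verified):

```julia
using Enzyme, LinearAlgebra, SparseArrays

# Isolated mul! with a Duplicated sparse complex matrix, mirroring the argument
# types in the MethodError above (sizes and names are arbitrary).
A = sprand(ComplexF64, 4, 4, 0.5)
x = rand(ComplexF64, 4)
y = zeros(ComplexF64, 4)

g!(y, A, x) = (mul!(y, A, x, true, true); nothing)

Enzyme.autodiff(Enzyme.Reverse, g!, Enzyme.Const,
    Enzyme.Duplicated(y, ones(ComplexF64, 4)),
    Enzyme.Duplicated(A, Enzyme.make_zero(A)),
    Enzyme.Duplicated(x, Enzyme.make_zero(x)))
```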
What are the full types of the method match failure? And what version of Enzyme are you using (is it the latest)?
```
LoadError: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{1, false, false, (false, false, false, false, false, false), false}, ::Const{typeof(mul!)}, ::Type{Const{Vector{ComplexF64}}}, ::Duplicated{Vector{ComplexF64}}, ::Duplicated{SparseMatrixCSC{ComplexF64, Int64}}, ::Duplicated{Vector{ComplexF64}}, ::Const{Bool}, ::Const{Bool})
The function `augmented_primal` exists, but no method is defined for this combination of argument types.
Closest candidates are:
augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{typeof(mul!)}, ::Type{RT}, ::Annotation{<:StridedVecOrMat}, ::Const{<:Union{SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, SubArray{Tv, 2, <:SparseArrays.AbstractSparseMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int64}}, I}} where I<:(AbstractUnitRange{<:Integer})} where {Tv, Ti}}, ::Annotation{<:StridedVecOrMat}, ::Annotation{<:Number}, ::Annotation{<:Number}) where RT
@ Enzyme ~/.julia/packages/Enzyme/azJki/src/internal_rules.jl:732
augmented_primal(::Any, ::Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::Annotation{T}...; kws...) where {RT, T}
@ QuadGKEnzymeExt ~/.julia/packages/QuadGK/BjmU0/ext/QuadGKEnzymeExt.jl:6
augmented_primal(::Any, ::Const{typeof(NNlib._dropout!)}, ::Type{RT}, ::Any, ::OutType, ::Any, ::Any, ::Any) where {OutType, RT}
@ NNlibEnzymeCoreExt ~/.julia/packages/NNlib/CkJqS/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl:318
```

And this is my environment (`Pkg.status()`):

```
Status `~/GitHub/Research/Undef/Autodiff QuantumToolbox/Project.toml`
[6e4b80f9] BenchmarkTools v1.5.0
[13f3f980] CairoMakie v0.12.16
[b0b7db55] ComponentArrays v0.15.19
[7da242da] Enzyme v0.13.15
[f6369f11] ForwardDiff v0.10.38
[1dea7af3] OrdinaryDiffEq v6.90.1
[33c8b6b6] ProgressLogging v0.1.4
[6c2fb7c5] QuantumToolbox v0.21.5 `~/.julia/dev/QuantumToolbox`
[731186ca] RecursiveArrayTools v3.27.3
[37e2e3b7] ReverseDiff v1.15.3
⌃ [0bca4576] SciMLBase v2.61.0
⌃ [1ed8b502] SciMLSensitivity v7.71.1
[5d786b92] TerminalLoggers v0.1.7
[e88e6eb3] Zygote v0.6.73
Info Packages marked with ⌃ have new versions available and may be upgradable.
```
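As a side note on the version question, the usual Pkg calls can be used to check and bump Enzyme (a sketch; nothing specific to this issue):

```julia
using Pkg
Pkg.status("Enzyme")   # show the currently installed Enzyme version
Pkg.update("Enzyme")   # move Enzyme (and compatible deps) to the latest allowed release
```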
Describe the bug 🐞

All the examples in the SciMLSensitivity.jl documentation use a user-defined function for the ODEProblem. I instead need to define a parameter-dependent SciMLOperator (e.g., a MatrixOperator, an AddedOperator, ...), but it fails.

Expected behavior

Returning the correct gradient without errors.
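For concreteness, a parameter-dependent operator of the kind meant here can be built along these lines (a small sketch mirroring the MRE earlier in the thread; names are illustrative):

```julia
using SciMLOperators, LinearAlgebra

# The scalar coefficient is recomputed from the parameters p on every operator update,
# so the ODE right-hand side U * u depends on p.
coef(a, u, p, t) = -p[1]
A = Diagonal(rand(ComplexF64, 10))
U = ScalarOperator(one(ComplexF64), coef) * MatrixOperator(A)
```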