-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Fix AD issues with various kernels #154
Conversation
I think the reason for failure of Zygote with |
@devmotion Do you suggest I override their pairwise implementation for now or open an issue/PR to |
I guess that should/could be resolved by adding a custom ChainRules-based adjoint for |
Sorry for the delay. I am having a hard time defining adjoints which aren't very computationally expensive for the Maha kernel. |
I guess you shouldn't need anything in particular for the Mahalanobis kernel but rather (just) a custom adjoint for the distance computations in Distances? In this case the Matrix cookbook, and in particular equations 72 and 81 are helpful. They show you that BTW I just noticed that the docstring of the kernel is incorrect since the distance computation (
P but not the inverse of P (Distances doesn't use the inverse either, according to their README).
|
@devmotion Thanks for pointing out the typo in the docstring! |
src/zygote_adjoints.jl
Outdated
a_b = map( | ||
x -> (first(last(x)) - last(last(x)))*first(x), | ||
zip( | ||
Δ, | ||
Iterators.product(eachslice(a, dims=dims), eachslice(b, dims=dims)) | ||
) | ||
) | ||
δa = reduce(hcat, sum(map(x -> B_B_t*x, a_b), dims=1)) | ||
δB = sum(map(x -> x*transpose(x), a_b)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would assume it should be possible to vectorize this code? What's the mathematical formula that you use here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is the same equations you mentioned earlier.
d((x-y)'*Q*(x-y))/dx = (Q + Q') * (x - y)
, d((x-y)'*Q*(x-y))/dy = - (Q + Q') * (x - y)
, and d((x-y)'*Q*(x-y))/dQ = (x - y)' * (x - y)
.
But this is being done for all pairwise combinations together using map
. It later sums these differences to get \deltaB
and others.
Please note that the current implementation is not correct. I am still debugging it. (it is only partially matching the intended result) If you happen to find any obvious mistakes please let me know. I am facing trouble in reducing the results of individual pairwise pullbacks to the final pullback. The way I am summing them is probably wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
julia> using Distances, Random;
julia> rng = MersenneTwister(123);
julia> M1, M2 = rand(rng, 2,3), rand(rng, 2,3);
julia> dist = SqMahalanobis(rand(rng, 2,2))
SqMahalanobis{Float64}([0.8654121434083455 0.2856979003853177; 0.617491887982287 0.46384720826189474])
julia> pairwise(dist, M1, M2; dims=2)
3×3 Array{Float64,2}:
0.371673 0.856348 0.742803
0.0233992 0.274278 0.276694
-0.036568 0.118487 0.0748149
julia> map(x -> evaluate(dist, first(x), last(x)), Iterators.product(eachslice(M1, dims=2), eachslice(M2, dims=2)))
3×3 Array{Float64,2}:
0.541253 0.912421 0.673273
0.0886328 0.285181 0.192394
0.0868399 0.166227 0.0616321
@devmotion isn't this wrong or have I done something silly? They are equal in case of euclidean. I feel this is the root of the problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should work if dist.qmat
is positive definite: JuliaStats/Distances.jl#174
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still does not solve the differences in the computed adjoints for the covariance matrix Q. My current implementation matches the second adjoint.
julia> using Distances, LinearAlgebra, FiniteDifferences, Random
julia> FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
julia> rng = MersenneTwister(123);
julia> M1, M2 = rand(rng,3,1), rand(rng,3,1)
([0.7684476751965699; 0.940515000715187; 0.6739586945680673], [0.3954531123351086; 0.3132439558075186; 0.6625548164736534])
julia> Q = Matrix(Cholesky(rand(rng, 3, 3), 'U', 0))
3×3 Array{Float64,2}:
0.343422 0.0638007 0.507151
0.0638007 0.0386393 0.19528
0.507151 0.19528 1.21186
julia> isposdef(Q)
true
julia> dist = SqMahalanobis(Q);
julia> fdm=FiniteDifferences.Central(5, 1);
julia> j′vp(fdm, pairwise, ones(1,1), dist, M1, M2)[1].qmat #A
3×3 Array{Float64,2}:
0.139125 0.365187 -0.238366
0.102751 0.393469 -0.404876
0.246873 0.419183 0.000130048
julia> j′vp(fdm, evaluate, 1, dist, M1[:, 1], M2[:, 1])[1].qmat #B
3×3 Array{Float64,2}:
0.139125 0.233969 0.00425358
0.233969 0.393469 0.00715332
0.00425358 0.00715332 0.000130048
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO it is best if (Sq)Mahalanobis distance is actually parameterized by the decomposition of Q, i.e, the upper or lower triangular matrix which is not constrained.
Yes, that would be the most natural way to ensure that it is always positive semi-definite (if the diagonal is non-negative) and optimization is performed in the correct space. So I guess users would want to use this parameterization even if it is not enforced by KernelFunctions and not directly supported by SqMahalanobis by using something like
function mykernel(L)
idxs = diagind(L)
@inbounds for i in idxs
L[i] = softplus(L[i])
end
return MahalanobisKernel(Array(L * L'))
end
Of course, it would be nice if (Sq)Mahalanobis would support specifying e.g. a Cholesky decomposition or PDMat directly (it could even be used for simplifying the computations since x'*Q*x = (L'*x)'*(L'*x)
in this case), but can't we work around this by checking gradients of the mykernel
setup instead of computing Q -> MahalanobisKernel(Q)
directly? That's at least how we do it in DistributionsAD, e.g. in https://github.com/TuringLang/DistributionsAD.jl/blob/a96b159ab25aab67d1a2076726e8b9c392eb6fc7/test/ad/distributions.jl#L18-L34.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but can't we work around this by checking gradients of the
mykernel
setup instead of computingQ -> MahalanobisKernel(Q)
directly?
Yeah that should work. Will try that out.
Regarding the issue with pairwise implementation which messes up FiniteDifferences
results, do you suggest I override the implementation for the time being?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you test the suggested parameterization the implementation of pairwise
shouldn't matter (since we do not test the intermediate step which might be affected by it).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True. Could we also change our side of the parametrization? i.e, the way it is stored in the struct. We could continue to allow initialization using a full matrix. This should allow for seamless AD regardless of how the user decides to initialize them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if we want to do that, I think this deserves some discussion first (and then a separate PR possibly). Ideally, Distances would just support arbitrary matrices and contain optimized implementations for specific array types. We just forward P
to SqMahalanobis, so ideally we wouldn't perform any transformations or computations. I'm also a bit worried that focusing on a specific parameterization might make it difficult for users who would like to use a different one (but still no dense matrix) or might lead to confusing behaviour.
test/basekernels/maha.jl
Outdated
@test_broken j′vp(fdm, x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2]) ≈ | ||
Zygote.pullback(x -> MahalanobisKernel(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1) | ||
@test all(j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1] .≈ | ||
Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@devmotion I tried doing what you suggested. The tests still fail. This error probably propagates and causes even the first test to fail.
julia> j′vp(fdm, x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), 1, [U, v1, v2])[1][1]
3×3 UpperTriangular{Float64,Array{Float64,2}}:
0.228808 0.00318764 -0.107503
⋅ -0.000391803 0.0132135
⋅ ⋅ 0.0438772
julia> Zygote.pullback(x -> SqMahalanobis(Array(x[1]'*x[1]))(x[2], x[3]), [U, v1, v2])[2](1)[1][1]
3×3 Array{Float64,2}:
0.228808 0.00318764 -0.107503
-0.0281234 -0.000391803 0.0132135
-0.0933875 -0.00130103 0.0438772
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me your output indicates that it basically works apart from the fact that Zygote incorrectly returns a dense matrix instead of an upper triangular matrix. Since U
was upper triangular, only the values above and on the diagonal should be returned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FiniteDifferences
if pretty good in matching the types. Zygote isn't. Do you suggest we manually check if the upper triangular part matches for now?
Edit: I don't we are addressing the major issue here. Our goal is to make the overall adjoint correct for kernelmatrix
. So maybe defining a custom zygote adjoint for UpperTriangular
which outputs a UpperTriangular
might solve the problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Were the call to UpperTriangular
inside the function, then the adjoint that you would get from Zygote
would also be UpperTriangular
. Maybe just do that?
test/basekernels/maha.jl
Outdated
fdm = FiniteDifferences.Central(5, 1); | ||
|
||
|
||
FiniteDifferences.to_vec(dist::SqMahalanobis{Float64}) = vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed? If possible, we should avoid this type piracy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes j′vp
only works when there is a to_vec
function defined for each argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering since according to the docs to_vec
is only needed for the inputs xs...
but not the evaluated function f
in j'vp(fdm, f, xs...)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what I understand, it is also needed for objects like SqMahalanobis
if they have parameters like qmat
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct, but actually for some reason we've not made FiniteDifferences
handle functions-with-data properly yet, so you'll have to build the SqMaha
object inside of the function that you're differentiating.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Somehow the solution to have to define kernelmatrix
again for the NeuralNetworkKernel
seems very hacky, isn't there another solution?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good. Just some style things.
src/zygote_adjoints.jl
Outdated
) | ||
δa = reduce(hcat, sum(map(x -> B_Bᵀ*x, a_b), dims=2)) | ||
δB = sum(map(x -> x*transpose(x), a_b)) | ||
return (qmat=δB,), δa, -δa |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is som discrepancy between the simple case above and this pullback - intuitively, from the simple case above I would assume that δB = sum_{i, j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}
. However, here you compute δB = sum_{i, j} (a_i - b_j) * (a_i - b_j)^T * Δ_{i,j}^2
. Probably one of them is incorrect (table 7 in https://notendur.hi.is/jonasson/greinar/blas-rmd.pdf indicates that the pairwise one is incorrect). Can we add the derivation of the adjoints according to https://www.juliadiff.org/ChainRulesCore.jl/dev/arrays.html as docstrings or comments, or maybe even have a separate PR for the Mahalanobis fixes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out. I think a separate PR for mahalanobis fixes makes more sense.
I guess one could define a "PreMetric" that evaluates |
Can we merge this and tackle each of the remaining AD issues in separate PRs? It is getting increasingly tricky to address multiple issues at once. Currently this PR does the following:
|
IMO this PR contains already too many changes, we should just focus on one AD problem at a time.
I thought the idea was not include these adjoints since they were missing a clean derivation/documentation and were incorrect? Or are you talking about the non-pairwise adjoints only? |
I meant only the non-pairwise adjoint . I will be removing the pairwise adjoints for now. |
Any objections to merging this? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no objections other than these tiny style-related things. This is a great PR.
#116