Known AD Failures #116

Open
8 of 15 tasks
theogf opened this issue May 16, 2020 · 27 comments

@theogf
Member

theogf commented May 16, 2020

Here is a list of the failures I observed in the tests from #114 with the different AD backends (ForwardDiff.jl, Zygote.jl and ReverseDiff.jl):

This is a good starting point for finding solutions.

@theogf theogf added the help wanted Extra attention is needed label Jun 15, 2020
@devmotion
Member

The Zygote problems with MaternKernel are caused by the fact that the partial derivative of besselk with respect to the first argument is defined as NaN in https://github.com/JuliaDiff/ChainRules.jl/blob/98c54587257b86cce6eb45f7870a75f897058d21/src/rulesets/packages/SpecialFunctions.jl#L46-L47 (and I assume the same problem exists for the other AD backends, since I get NaN for all of them when I try to run the commented out AD tests). I guess one would have to implement https://dlmf.nist.gov/10.38 to fix it.
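Until an analytic rule exists, a finite-difference estimate of the order derivative can serve as a reference value when testing any candidate rule. This is only a sketch, assuming SpecialFunctions.jl is installed; the helper name `dbesselk_dnu` and the step size `h` are hypothetical choices and not part of any package:

```julia
using SpecialFunctions: besselk

# Central finite difference of K_nu(x) with respect to the order nu.
# Hypothetical helper; its accuracy is limited by the step size h.
dbesselk_dnu(nu, x; h=1e-5) = (besselk(nu + h, x) - besselk(nu - h, x)) / (2h)

# Reference value that an AD rule for the order derivative could be compared against.
ref = dbesselk_dnu(1.5, 2.0)
```

Such an estimate is too inaccurate to use as the rule itself, but it is enough to detect a rule that returns NaN or the wrong sign.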

@theogf
Member Author

theogf commented Jun 16, 2020

Haha, writing these derivatives sounds like one should write a whole package about Bessel functions.

@yebai
Contributor

yebai commented Jun 24, 2020

@sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.

@devmotion
Member

BTW I found some publication from 2016 with closed-form expressions of the derivatives of the Bessel functions with respect to the order. I opened an issue at JuliaDiff/ChainRules.jl#208 to discuss how one would deal with the additional dependencies needed for their implementations (they contain hypergeometric functions).

@devmotion
Member

We might want to refactor KernelSum and KernelProd (making them concretely typed and allowing both tuples and vectors of kernels similar to TensorProduct, and probably removing the weights in KernelSum) before fixing any AD issues there.

@theogf
Member Author

theogf commented Jun 24, 2020

Agreed!
There is also a general AD issue when using Transform, where the pullback on ColVecs and RowVecs returns a vector of vectors; fixing this would tick off a good portion of the issues.

@sharanry
Contributor

sharanry commented Aug 2, 2020

> @sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.

Sorry for the late reply. I somehow didn't get a notification for this comment and only found it by chance while browsing the issues. I am looking into it.

@sharanry
Contributor

sharanry commented Aug 8, 2020

The probable reason Zygote fails for FunctionTransform is the use of Base.mapslices in:

_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))

Base.mapslices is mutating the array. Not sure why. JuliaLang/julia#17266 should have fixed this.

julia> Zygote._pullback(x-> mapslices(x->sin.(x), x, dims=1), rand(3,3))[2](ones(3,3))
ERROR: Mutating arrays is not supported

@devmotion
Member

mapslices still mutates a temporary array. The linked PR just ensures that:

> mapslices never modifies the input array. It allocates temporary storage and copies each slice into it before calling the user-function.

@sharanry
Contributor

sharanry commented Aug 8, 2020

> mapslices still mutates a temporary array. The linked PR just ensures that

Oh makes sense. Do you see any other efficient way to apply a function transform for a matrix/ColVecs/RowVecs?

@devmotion
Member

I see the following possibilities here:

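function _map(t::FunctionTransform, x::ColVecs)
    # Apply t.f to each column; `vals` is a vector of the per-column results.
    vals = map(axes(x.X, 2)) do i
        t.f(view(x.X, :, i))
    end
    # ColVecs wraps a matrix, so concatenate the results column-wise first.
    return ColVecs(reduce(hcat, vals))
end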

(Zygote should support this automatically)
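To check that Zygote can differentiate through this kind of column-wise `map`, here is a minimal sketch on a plain matrix, without the ColVecs wrapper. The function name `apply_columns` is hypothetical, and `reduce(hcat, ...)` is one assumed way of reassembling the columns into a matrix:

```julia
using Zygote

# Hypothetical helper: apply f to every column of X without mapslices,
# then concatenate the resulting vectors back into a matrix.
function apply_columns(f, X::AbstractMatrix)
    cols = map(i -> f(view(X, :, i)), axes(X, 2))
    return reduce(hcat, cols)
end

X = rand(3, 3)

# Gradient of sum(sin.(X)) with respect to X; analytically this is cos.(X).
grad, = Zygote.gradient(X -> sum(apply_columns(x -> sin.(x), X)), X)
```

Nothing in `apply_columns` writes into an existing array, which is the property that the `mapslices` version lacked.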

@sharanry
Contributor

sharanry commented Aug 8, 2020

I just ran a quick benchmark for one other possibility, which would require us to define an adjoint for a generator. The methods you mentioned are probably better.

julia> @btime hcat(map(x->sin.(x), (eachslice(rand(1000,1000); dims=1)))...)
  16.586 ms (2015 allocations: 23.09 MiB)
julia> @btime mapslices(x->sin.(x), rand(1000,1000); dims=1)
  12.189 ms (7505 allocations: 23.18 MiB)

@devmotion
Member

A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls to rand).
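As a small illustration of the interpolation advice, assuming BenchmarkTools.jl is installed (the variable name `X` is arbitrary):

```julia
using BenchmarkTools

X = rand(1000, 1000)

# Without `$`, X is looked up as a non-constant global inside the benchmark,
# which adds dynamic-dispatch overhead to the measurement.
@btime mapslices(x -> sin.(x), X; dims=1);

# With `$`, the value is interpolated into the benchmark expression, so the
# timing covers only the mapslices call on an already-built matrix.
@btime mapslices(x -> sin.(x), $X; dims=1);
```

Writing `$(rand(1000, 1000))` directly works the same way: the matrix is created once, before timing starts.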

@sharanry
Contributor

sharanry commented Aug 8, 2020

> A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls to rand).

Thanks! Wasn't aware of this.
This however gave unexpected results for mapreduce.

julia> @btime mapslices(x->sin.(x), $(rand(1000,1000)); dims=1);
  10.581 ms (7503 allocations: 15.55 MiB)

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=1));
  914.970 ms (5002 allocations: 3.74 GiB)

@devmotion
Member

Shouldn't you use eachslice(...; dims=2) or eachcol?

@sharanry
Contributor

sharanry commented Aug 8, 2020

I don't think it is making much difference performance wise at least.

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=2));
  952.564 ms (5002 allocations: 3.74 GiB)

@devmotion
Member

Can you check if the function is type-stable? I suspect it might not be, which would explain the number of allocations. The problem might be that it returns a different type if eachslice(...) is empty. Specifying an init kwarg might be helpful.
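A sketch of how one could inspect this using Base only. Independently of type stability, the 3.74 GiB figure is also consistent with the cost of the fold itself: if `mapreduce` falls back to a plain left fold (as it does for a lazy iterator), each `hcat` allocates a new matrix with one more column, so a 1000×1000 Float64 input allocates roughly 1000 · 1001 / 2 · 8000 bytes ≈ 3.7 GiB. That explanation is an assumption and is not confirmed in the thread. The names `fold_cols` and `concat_cols` are hypothetical:

```julia
# Both functions apply sin to each column of X and return a matrix.
# Left fold: every hcat call allocates a new, one-column-wider matrix.
fold_cols(X) = foldl(hcat, (sin.(c) for c in eachcol(X)))
# reduce(hcat, ...) on a vector of vectors builds the result in one go.
concat_cols(X) = reduce(hcat, map(c -> sin.(c), eachcol(X)))

X = rand(200, 200)
fold_cols(X); concat_cols(X)   # warm up so compilation is not counted

bytes_fold = @allocated fold_cols(X)
bytes_concat = @allocated concat_cols(X)

# Inferred return types; a Union or Any entry would point to type instability.
inferred = Base.return_types(fold_cols, (Matrix{Float64},))
```

Comparing `bytes_fold` with `bytes_concat` shows how much of the allocation comes from the repeated concatenation rather than from inference problems.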

@willtebbutt
Member

Just for the record -- we should be using ChainRulesCore to define pullbacks for Zygote, and ChainRulesTestUtils to test those implementations -- see e.g. here for example usage.

Plans are in the works to transfer both Tracker and ReverseDiff over to use ChainRules at some point (we know how we're going to do it, just waiting for code to get written), so this will future-proof AD in the package.
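For readers unfamiliar with that workflow, a minimal sketch of defining a pullback with ChainRulesCore and checking it with ChainRulesTestUtils. The function `sqexp` is a hypothetical stand-in and is not taken from KernelFunctions; the sketch assumes a ChainRulesCore version that provides `NoTangent`:

```julia
using ChainRulesCore
using ChainRulesTestUtils: test_rrule

# Hypothetical scalar function standing in for a piece of kernel code.
sqexp(d::Real) = exp(-d^2 / 2)

# Reverse rule: returns the primal value and a pullback that maps the output
# cotangent dy to cotangents for (the function itself, d).
function ChainRulesCore.rrule(::typeof(sqexp), d::Real)
    y = sqexp(d)
    sqexp_pullback(dy) = (NoTangent(), -d * y * dy)
    return y, sqexp_pullback
end

# Compares the hand-written pullback against finite differences.
test_rrule(sqexp, 0.7)
```

Because Zygote picks up `rrule` definitions, a rule written this way does not need a separate Zygote-specific adjoint.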

@yiyuezhuo

I can't figure out how to define mapslices adjoint only in ChainRulesCore. If f is an anonymous function, how can we get its backward (rrule) from ChainRulesCore, while it can be done in Zygote using gradient?

@devmotion
Member

KernelFunctions doesn't use mapslices anymore, so custom adjoints for mapslices are no longer required for this project. Nevertheless, an implementation of an adjoint of mapslices in ChainRules requires a solution to JuliaDiff/ChainRulesCore.jl#68 AFAICT.

@yiyuezhuo

yiyuezhuo commented Aug 30, 2020

I see. I checked the source of the last release in order to backport TransformedKernel to Stheno, as Stheno doesn't support KernelFunctions, and found that mapslices code. But after thinking about how to implement it in ChainRules or Zygote, I just disabled a check in Stheno to re-enable the gradient of f.(ColVecs(X)), since I feel that using ChainRules or Zygote will not make much difference.

@devmotion
Member

The latest releases don't use mapslices; it was replaced in #152. I guess you can use something similar instead of mapslices in Stheno as well.

@sharanry
Contributor

Regarding FBMKernel not working with ForwardDiff: it seems to be producing NaN values incorrectly. According to the ForwardDiff documentation, the fix for this is to "enable ForwardDiff's NaN-safe mode by setting the NANSAFE_MODE_ENABLED constant to true in ForwardDiff's source". They currently do not allow users to enable it dynamically [Issue].

@devmotion
Member

You can use JuliaDiff/ForwardDiff.jl#451 if you do not want to edit the source code.
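For reference, more recent ForwardDiff documentation describes switching NaN-safe mode on through Preferences.jl instead of editing the source. A minimal sketch, assuming a ForwardDiff version that supports the `nansafe_mode` preference; whether this corresponds to the linked PR is not confirmed here:

```julia
using ForwardDiff, Preferences

# Stores the preference for the active project (LocalPreferences.toml).
# Julia has to be restarted so that ForwardDiff is recompiled with
# NaN-safe mode enabled.
set_preferences!(ForwardDiff, "nansafe_mode" => true)
```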

@cgeoga

cgeoga commented Feb 8, 2022

Hi--it just occurs to me to share this here, but I recently finished a project for computing derivatives of besselk with respect to the order parameter precisely for the purpose of fitting Matern covariances (example Matern kernel implementation here). The strategy that worked ended up being a re-implementation of besselk in Julia that admits fast and accurate AD derivatives with ForwardDiff.jl. The re-implemented besselk itself is not quite as accurate as the AMOS one linked in SpecialFunctions, but the derivatives are pretty accurate. Not quite to machine double precision, but reasonably close. And very fast.

I'm not sure how helpful this is because the derivatives are at present pretty ForwardDiff-specific. I would guess that it would be possible to reach compatibility with other AD tools, perhaps at a slight cost of performance by eliminating some special branches in the current implementation, but I honestly don't understand how Zygote works at all so I can't promise it.

Anyways, just writing here in case it is helpful.

@devmotion
Member

devmotion commented Feb 8, 2022

I came across https://www.tandfonline.com/doi/pdf/10.1080/10652469.2016.1164156 a while ago; it contains closed-form expressions of the derivatives using e.g. hypergeometric functions. In principle these could be used with other AD backends as well, but I don't know if there are any numerical problems, how slow/fast the evaluation with HypergeometricFunctions would be, and if SpecialFunctions would take a dependency on HypergeometricFunctions (I assume not, since it would introduce a circular dependency).

@cgeoga

cgeoga commented Feb 8, 2022

I also saw that paper and was interested in just using that before undertaking a more from-scratch approach. But there are a few challenges with using the representations in Santander. For one, as you point out, evaluating the generalized hypergeometric functions like 3F4 and 2F3 is a task of comparable difficulty. I love HypergeometricFunctions.jl, but that's a lot of pressure to put on that package, which at the very least in my experience is very slow when the besselk argument is small (which is unfortunately where the accurate derivatives matter the most). More importantly, though, the representation in Santander is hard in a bunch of edge cases. When $\nu$ is an integer or near-integer, there are several problems, both with cancellations and with trig functions blowing up. If $\nu = 1 + 10^{-8}$ or something, that ostensibly exact equation might give literally zero digits of accuracy. The exact derivatives when $\nu + 1/2$ is an integer are particularly gnarly, and I've never seen them for any case besides $\nu = 1/2$.

Our project was enough of a hassle that we ended up writing a paper about it, and almost all the work was in handling the problems of $\nu$ being exactly or nearly an integer or half-integer. I don't think there's any way around a gnarly branching function to handle the derivatives in those cases. And if you look at our timings (table one of the paper), it will probably be hard to come anywhere near those speeds at even comparable accuracy around those edge cases.

I've actually thought about asking the SpecialFunctions package folks if they'd be interested in some of our code being added to that package, but considering that we are a bit cavalier in giving up the last couple digits of accuracy I'm a bit concerned that it's not a great fit.

In any case, just posting here for your consideration. If somebody manages to implement them with exact expressions in a way that is tolerably fast and handles those edge cases, I'll be the first person to celebrate. In the meantime, though, I wouldn't be shocked if Zygote compatibility was possible. I just really don't know enough to conjecture about how much of a project it would be.
