Hi @martinResearch and thanks for creating this package, which I stumbled upon while reading jax-ml/jax#1032. This issue is about understanding it a little better.
From what I can tell, known sparsity patterns are most useful in conjunction with coloring algorithms, because they reduce the number of forward (or reverse) passes needed to compute a Jacobian. Typically, the Jacobian of a function $f : R^n \to R^m$ would require $n$ forward passes, one for each of the Jacobian-vector products associated with the basis vectors of $R^n$. Grouping basis vectors together is what the coloring step is all about, and it can reduce the number of forward passes from $O(n)$ to $O(1)$ in the best cases. See https://epubs.siam.org/doi/10.1137/S0036144504444711 for more details, or Example 5.7 from https://tomopt.com/docs/TOMLAB_MAD.pdf#page26.
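As a toy illustration of what I mean by coloring (a sketch of my own in JAX, not code from either package): for a function whose Jacobian is tridiagonal with wrap-around corners, three colors suffice, so three JVPs recover the full matrix instead of $n$:

```python
import jax
import jax.numpy as jnp

def f(x):
    # output[i] = x[i]**2 + x[i-1]*x[i+1] (indices mod n), so row i of the
    # Jacobian only touches columns i-1, i, i+1: tridiagonal with corners.
    return x ** 2 + jnp.roll(x, 1) * jnp.roll(x, -1)

n = 9  # multiple of 3 so the wrap-around does not break the coloring
x = jnp.arange(1.0, n + 1)

colors = jnp.arange(n) % 3  # columns j, j+3, j+6, ... never share a row
J = jnp.zeros((n, n))
for c in range(3):
    seed = jnp.where(colors == c, 1.0, 0.0)  # sum of the basis vectors of color c
    _, jvp_out = jax.jvp(f, (x,), (seed,))   # one forward pass per color
    for j in range(n):
        if j % 3 == c:
            rows = jnp.array([(j - 1) % n, j, (j + 1) % n])  # known sparsity pattern
            J = J.at[rows, j].set(jvp_out[rows])             # decompress column j

assert jnp.allclose(J, jax.jacfwd(f)(x))  # matches the dense Jacobian (n passes)
```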
As you stated in jax-ml/jax#1032 (comment), your library does not rely on this paradigm. When you talk of a "single forward pass using matrix-matrix products at each step of the forward computation", is it correct that you still end up computing the JVP with every single basis vector? In other words, while the runtime may be low in practice thanks to vectorization and efficient sparse operations, the theoretical complexity remains $O(n)$?

Thanks in advance for your response
Hi @gdalle,
"is it correct that you still end up computing the JVP with every single basis vector" in some sense yes: I start with the identity matrix for the derivatives of the input vector (
) and then multiply that derivates matrix with the Jacobian of each operation in the chain using the forward chain rule . Each column of the initial identity matrix can be interpreted as a basis vector. Even if I were to use dense matrices to represent the derivatives, also the complexity would be O(n) this would differs slightly from an implementation of Forward AD that would call n times the end-to-end function because the code is executed only once.
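To make this concrete, here is a rough sketch (my own simplification, not the actual code of this package) of forward-mode AD where the derivative is seeded with the identity matrix and then propagated through each operation of the chain with sparse matrix-matrix products:

```python
import numpy as np
import scipy.sparse as sp

def f_with_sparse_jacobian(x):
    n = x.size
    D = sp.identity(n, format="csr")  # derivative of the input w.r.t. itself

    # step 1: y = x**2 (elementwise, so the local Jacobian is diagonal)
    y = x ** 2
    D = sp.diags(2.0 * x) @ D  # forward chain rule: local Jacobian times D

    # step 2: z = first difference of y (local Jacobian is bidiagonal)
    z = y[1:] - y[:-1]
    J_diff = sp.diags([-np.ones(n - 1), np.ones(n - 1)], [0, 1], shape=(n - 1, n))
    D = J_diff @ D  # D stays sparse at every step

    return z, D  # D is the Jacobian of z w.r.t. x, built in a single pass

x = np.array([1.0, 2.0, 3.0, 4.0])
z, J = f_with_sparse_jacobian(x)
print(J.toarray())  # [[-2. 4. 0. 0.], [0. -4. 6. 0.], [0. 0. -6. 8.]]
```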
"the theoretical complexity remains O(n)?", It depends on how you define "theoretical" and were we draw the line between theoretical and practice.
If the function you want to differentiate has a Jacobian whose sparsity is s (the fraction of zero entries), I would expect the complexity of the method that uses the coloring approach to be O((1-s)*n), with (1-s)*n an approximation of the number of groups, and hence passes, needed. Is that correct?
The complexity of the approach implemented here would also theoretically be O((1-s)*n) if we assume that all the intermediate derivatives in the forward chain rule also have sparsity s or greater. I believe this is in general a reasonable assumption, because it is rare for intermediate derivatives to be denser than the final derivatives: that would require derivative values that cancel each other out at some step, increasing the sparsity in the following steps of the chain. If we can prove for a particular class of problems that the sparsity of the intermediate derivatives is at least s, then the speedup is not only practical but also theoretical for that class of problems.
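As a contrived toy illustration of that cancellation argument (my own example): take y[i] = x[i] + x[i-1] followed by z = y - shift(x), so that z = x. The intermediate Jacobian dy/dx is bidiagonal, but the end-to-end Jacobian dz/dx collapses to the sparser identity only because the two shifted terms cancel exactly:

```python
import numpy as np

n = 3
S = np.eye(n, k=-1)          # shift matrix: (S @ x)[i] = x[i-1], with x[-1] := 0

J_y = np.eye(n) + S          # dy/dx, bidiagonal: 2n - 1 nonzeros
J_z_wrt_y = np.eye(n)        # z = y - S @ x depends on y through the identity
J_z_direct = -S              # ... and on x directly through -S

# Forward chain rule: the sub-diagonal terms cancel, leaving the identity.
J_z = J_z_wrt_y @ J_y + J_z_direct
print(np.count_nonzero(J_y), np.count_nonzero(J_z))  # 5 vs 3 for n = 3
```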
ping @adrhill