jax.numpy.linalg.multi_dot is O(2^N) in the number of matrices being multiplied #25051
Comments
The time that you're seeing here is actually the compile time (edit: actually "tracing" time), not the runtime! Take a look at the JAX FAQ about microbenchmarks for more details. If you update your benchmarking code as follows:

```python
import time

import jax
import jax.numpy as jnp

# xs and numpy_time come from the original benchmark
@jax.jit
def fun(xs):
    return jnp.linalg.multi_dot(xs)

fun(xs).block_until_ready()  # compile the function

jax_start = time.time()
fun(xs).block_until_ready()
jax_time = time.time() - jax_start
print(f"{jax_time=} s")
print(f"{numpy_time=} s")
```

this will print something like
Hope this helps!
Quick follow-up: @mattjj has pointed out that your point about …
We could also implement the matrix chain multiplication DP algorithm ourselves directly. (I'm assuming there's no easy call into numpy or opt_einsum to compute it.) We probably don't want to stick with the O(2^N) behavior by default, especially since this API really is about providing the O(N^3) behavior.
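For reference, the O(N^3) dynamic program mentioned here (the classic matrix chain ordering algorithm, as numpy uses) can be sketched in plain Python — this is a textbook sketch with a hypothetical function name, not JAX's or numpy's actual code:

```python
def matrix_chain_order(dims):
    """O(N^3) DP for the optimal multiplication order.

    dims[i] x dims[i + 1] is the shape of matrix i, so N matrices
    correspond to a dims list of length N + 1.  Returns the cost
    table m and the split table s, where s[i][j] is the index at
    which the optimal product of matrices i..j splits.
    """
    n = len(dims) - 1
    m = [[0] * n for _ in range(n)]  # min scalar multiplications for i..j
    s = [[0] * n for _ in range(n)]  # optimal split point for i..j
    for length in range(2, n + 1):          # chain length
        for i in range(n - length + 1):
            j = i + length - 1
            m[i][j] = float("inf")
            for k in range(i, j):
                cost = (m[i][k] + m[k + 1][j]
                        + dims[i] * dims[k + 1] * dims[j + 1])
                if cost < m[i][j]:
                    m[i][j] = cost
                    s[i][j] = k
    return m, s
```

For shapes 10x100, 100x5, 5x50, this picks the split (A @ B) @ C at a cost of 7500 scalar multiplications, versus 75000 for A @ (B @ C).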
Good point! It also seems possible that switching lines 2121 to 2122 (at commit 34a2f0c) would get us that behavior already, although I'm not very familiar with the specific algorithms.
Yes, I understand this is something that only shows up at trace time. The issue is that since compiling is exponential in the number of matrices, it's unusable for multiplying more than 15 or so matrices. I copied numpy's implementation here and it doesn't have this issue.
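For context, numpy's approach computes the split table once (via the O(N^3) DP) and then multiplies recursively along it. A minimal sketch of that recursive evaluation step, with a hand-written split table for three matrices — hypothetical helper name, not the actual numpy or JAX code:

```python
import numpy as np

def multiply_with_order(arrays, s, i, j):
    """Multiply arrays[i..j] following the split table s, where
    s[i][j] is the index at which the optimal product splits."""
    if i == j:
        return arrays[i]
    k = s[i][j]
    return (multiply_with_order(arrays, s, i, k)
            @ multiply_with_order(arrays, s, k + 1, j))

# Hand-written split table encoding (A @ B) @ C for three matrices,
# which is optimal for these shapes.
s = [[0, 0, 1],
     [0, 0, 1],
     [0, 0, 0]]
A, B, C = np.ones((10, 100)), np.ones((100, 5)), np.ones((5, 50))
result = multiply_with_order([A, B, C], s, 0, 2)
```

The recursion itself is linear in N; all of the cost-model work lives in building s, which is why the O(N^3) DP matters.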
Description
The component of jax.numpy.linalg.multi_dot which computes the optimal order of matrix multiplication is O(2^N) in the number of matrices being multiplied together. This is much worse than numpy.linalg.multi_dot, which is O(N^3). This shows up only when the function is being traced or is run outside of jit, since the ordering seems to be computed at trace time.

Example:
I have more code for reproducing the issue here. If we plot the execution time and compare it to numpy, we see that the growth is exponential.
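A minimal sketch of such a reproduction (matrix sizes are arbitrary; `jax.make_jaxpr` is used here to isolate tracing time from compilation and execution):

```python
import time

import jax
import jax.numpy as jnp

def trace_time(n, size=4):
    """Time how long tracing multi_dot takes for a chain of n matrices."""
    xs = [jnp.ones((size, size)) for _ in range(n)]
    start = time.time()
    jax.make_jaxpr(jnp.linalg.multi_dot)(xs)  # traces without executing
    return time.time() - start

for n in (2, 4, 8):
    print(n, trace_time(n))
```

Plotting `trace_time(n)` against n as the linked code does is what exposes the exponential growth.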
Other info
I copied numpy.linalg.multi_dot and made minimal changes to make it jittable, and the result seems to work well (see here). The jax version uses jnp.einsum with optimize=True, so maybe the error is in opt_einsum?

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.12.6 (main, Sep 6 2024, 19:03:47) [Clang 15.0.0 (clang-1500.3.9.4)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='4310L-212310-M.local', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')