
jax.numpy.linalg.multi_dot is O(2^N) in the number of matrices being multiplied #25051

Open
rohan-hitchcock opened this issue Nov 22, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@rohan-hitchcock

Description

The component of jax.numpy.linalg.multi_dot which computes the optimal order of matrix multiplication is O(2^N) in the number of matrices being multiplied together. This is much worse than numpy.linalg.multi_dot, which is O(N^3). The cost shows up only when the function is being traced or run outside of jit, since the ordering appears to be computed at trace time.

Example:

import time

import jax
import jax.numpy as jnp

import numpy as np

N = 12     # jax's multi_dot scales as O(2^N) in the chain length while numpy's does not

WIDTH = 50 # multiply N - 1 matrices of shape (WIDTH, WIDTH)
SEED = 0   # seed for the random matrix entries

def get_matrices(key, widths: list[int]):
    shapes = zip(widths, widths[1:])
    keys = jax.random.split(key, num=len(widths) - 1)
    return [jax.random.normal(k, shape) for k, shape in zip(keys, shapes)]


key = jax.random.key(SEED)

widths = N * [WIDTH]
# alt:
# widths = [38, 49, 26, 32, 29, 28, 49, 46, 41, 46, 49, 42]
#
# alt:
# key, key_widths = jax.random.split(key)
# widths = jax.random.randint(key_widths, (N,), WIDTH // 2, WIDTH).tolist()

xs = get_matrices(key, widths)

# time jax's multi_dot
jax_start = time.time()
jnp.linalg.multi_dot(xs).block_until_ready()
jax_time = time.time() - jax_start

# time numpy's multi_dot
xs = [np.asarray(x) for x in xs]
numpy_start = time.time()
np.linalg.multi_dot(xs)
numpy_time = time.time() - numpy_start

print(f"{jax_time=} s")
print(f"{numpy_time=} s")

I have more code for reproducing the issue here. If we plot the execution time and compare it to numpy we see it is exponential:
[Plot: execution time of jax.numpy.linalg.multi_dot vs numpy.linalg.multi_dot as the number of matrices grows; the jax curve grows exponentially]

Other info

I copied numpy.linalg.multi_dot and made minimal changes to make it jittable, and the result seems to work well (see here). The jax version uses jnp.einsum with optimize='optimal', so maybe the bottleneck is in opt_einsum?

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.12.6 (main, Sep 6 2024, 19:03:47) [Clang 15.0.0 (clang-1500.3.9.4)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='4310L-212310-M.local', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')

@rohan-hitchcock rohan-hitchcock added the bug Something isn't working label Nov 22, 2024
@dfm
Collaborator

dfm commented Nov 22, 2024

The time that you're seeing here is actually the compile time (edit: more precisely, the tracing time), not the runtime! Take a look at the JAX FAQ about microbenchmarks for more details.

If you update your benchmarking code as follows:

@jax.jit
def fun(xs):
  return jnp.linalg.multi_dot(xs)

fun(xs).block_until_ready()  # compile the function
jax_start = time.time()
fun(xs).block_until_ready()
jax_time = time.time() - jax_start
print(f"{jax_time=} s")
print(f"{numpy_time=} s")

this will print something like

jax_time=0.00041556358337402344 s
numpy_time=0.0005590915679931641 s

Hope this helps!
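One way to make this split explicit is to time the trace/compile step separately from the run step using JAX's ahead-of-time lowering API. The sketch below is only illustrative (placeholder (50, 50) matrices and a chain of 12); the exponential ordering search lands in the first measurement, not the second:

import time

import jax
import jax.numpy as jnp

def fun(xs):
    return jnp.linalg.multi_dot(xs)

xs = [jnp.ones((50, 50)) for _ in range(12)]

# Tracing + lowering + XLA compilation, measured together.
start = time.time()
compiled = jax.jit(fun).lower(xs).compile()
print(f"trace+compile: {time.time() - start:.4f} s")

# Pure run time of the already-compiled computation.
start = time.time()
compiled(xs).block_until_ready()
print(f"run: {time.time() - start:.6f} s")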

@dfm
Collaborator

dfm commented Nov 22, 2024

Quick follow-up: @mattjj has pointed out that your point about opt_einsum (which I missed - sorry!) is a good one! If the tracing time is a significant bottleneck in your use case, perhaps it would be useful to (at least!) provide user control of the einsum optimization?

@mattjj
Collaborator

mattjj commented Nov 22, 2024

We could also implement the matrix chain multiplication DP algorithm ourselves directly. (I’m assuming there’s no easy call into numpy or opt_einsum to compute it.)

We probably don’t want to stick with 2^N behavior by default, especially since this API really is about providing the O(N^3) ordering behavior.
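For reference, here is a minimal sketch of that DP, the classic matrix-chain-order recurrence that numpy.linalg.multi_dot is based on, assuming the shapes are static Python ints known at trace time (matrix_chain_order and multiply_chain are illustrative names, not existing APIs):

import numpy as np
import jax.numpy as jnp

def matrix_chain_order(dims):
    # dims has length N + 1; matrix i has shape (dims[i], dims[i + 1]).
    n = len(dims) - 1
    m = np.zeros((n, n))             # m[i, j]: min scalar-multiplication cost for chain i..j
    s = np.zeros((n, n), dtype=int)  # s[i, j]: optimal split point for chain i..j
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            m[i, j] = np.inf
            for k in range(i, j):
                cost = m[i, k] + m[k + 1, j] + dims[i] * dims[k + 1] * dims[j + 1]
                if cost < m[i, j]:
                    m[i, j] = cost
                    s[i, j] = k
    return s

def multiply_chain(arrays, s, i, j):
    # Multiply arrays[i..j] following the split table; only the jnp.dot
    # calls end up in the traced computation.
    if i == j:
        return arrays[i]
    k = int(s[i, j])
    return jnp.dot(multiply_chain(arrays, s, i, k),
                   multiply_chain(arrays, s, k + 1, j))

Given arrays = [A0, ..., An], one would build dims = [A.shape[0] for A in arrays] + [arrays[-1].shape[1]] and call multiply_chain(arrays, matrix_chain_order(dims), 0, len(arrays) - 1). Filling the table is O(N^3) Python work at trace time, and the resulting trace contains only the chain of dots.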

@dfm
Collaborator

dfm commented Nov 22, 2024

Good point! It also seems possible that switching optimize='optimal' to optimize='dp' here:

jax/jax/_src/numpy/linalg.py

Lines 2121 to 2122 in 34a2f0c

return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload]
optimize='optimal', precision=precision)

would get us that behavior already, although I'm not very familiar with the specific algorithms.
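As a rough check on where the time goes, one can bypass JAX entirely and time just the contraction-path search in opt_einsum for a chain of square matrices, comparing the 'optimal' and 'dp' optimizers (this sketch assumes an opt_einsum version that supports optimize='dp', i.e. >= 3.2):

import time

import numpy as np
import opt_einsum

N, WIDTH = 12, 50
mats = [np.ones((WIDTH, WIDTH)) for _ in range(N)]
# Build the chain subscripts "ab,bc,cd,..." for N matrices.
subscripts = ",".join(
    opt_einsum.get_symbol(i) + opt_einsum.get_symbol(i + 1) for i in range(N)
)

for opt in ("optimal", "dp"):
    start = time.time()
    opt_einsum.contract_path(subscripts, *mats, optimize=opt)
    print(f"{opt}: {time.time() - start:.4f} s")

If the 'optimal' search is what blows up as N grows while 'dp' stays fast, then switching the multi_dot call site to optimize='dp' (or exposing the choice to users) would remove the exponential tracing cost.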

@dfm
Collaborator

dfm commented Nov 22, 2024

I want to ping @jakevdp here because he recently removed our own implementation of the DP algorithm in #21115, so he probably has opinions!

@rohan-hitchcock
Author

Yes, I understand this is something that only shows up at trace time. The issue is that, since compilation is exponential in the number of matrices, it's unusable for multiplying more than 15 or so matrices.

I copied numpy's implementation here and it doesn't have this issue.

3 participants