Add a way to get the tree size of a jaxpr #24995

Open
carlosgmartin opened this issue Nov 20, 2024 · 7 comments
Labels
enhancement New feature or request

Comments

@carlosgmartin
Contributor

Feature request: Add a way to get the size of a Jaxpr and ClosedJaxpr, in terms of the size of the tree itself (not the length of its string representation).

This could be done by adding a __len__ dunder to these classes, or another method.
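A minimal sketch of what such a measure could look like. It assumes "tree size" means the total number of equations, counting the equations of nested sub-jaxprs too (for example the `jaxpr` parameter of a `jit`/`pjit` equation, or the branches of a `cond`). `jaxpr_tree_size` and `_sub_jaxprs` are hypothetical helpers, not existing JAX APIs. They duck-type on the `.eqns` attribute instead of importing the `Jaxpr`/`ClosedJaxpr` classes, because those classes have moved between modules across JAX releases.

```python
import jax
import jax.numpy as jnp


def _sub_jaxprs(value):
    """Yield any jaxprs found in an equation parameter value (hypothetical helper).

    Sub-jaxprs appear as Jaxpr/ClosedJaxpr objects (e.g. the `jaxpr` param of a
    jit equation) or inside tuples/lists (e.g. the `branches` param of `cond`).
    """
    if hasattr(value, "eqns"):
        yield value
    elif isinstance(value, (tuple, list)):
        for item in value:
            yield from _sub_jaxprs(item)


def jaxpr_tree_size(jaxpr) -> int:
    """Hypothetical helper: count equations, recursing into nested jaxprs."""
    total = 0
    for eqn in jaxpr.eqns:  # `.eqns` exists on both Jaxpr and ClosedJaxpr
        total += 1  # the equation itself (e.g. the outer jit call)
        for value in eqn.params.values():
            for sub in _sub_jaxprs(value):
                total += jaxpr_tree_size(sub)
    return total


@jax.jit
def inner(x):
    return jnp.sin(x) + 1.0


def outer(x):
    return inner(x) * 2.0


closed = jax.make_jaxpr(outer)(1.0)
# The top level sees a single jit equation for `inner`; the tree size also
# counts the equations inside it.
print(len(closed.eqns), jaxpr_tree_size(closed))
```

Counting the wrapping `jit` equation as one node in addition to its body is a design choice; a variant could skip call-like equations and count only their contents.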

@carlosgmartin carlosgmartin added the enhancement New feature or request label Nov 20, 2024
@superbobry
Collaborator

Can you add a few details about what you are trying to do, please?

@jakevdp
Collaborator

jakevdp commented Nov 20, 2024

Does len(jaxpr.eqns) do what you want?

@carlosgmartin
Contributor Author

carlosgmartin commented Nov 20, 2024

@superbobry I was looking for some rough measure of jaxpr "size" or "complexity"/"simplicity" in order to compare different potential implementations of a JAX function, alongside cost analysis. (See this comment, for example.)

@jakevdp No.

import jax
from jax import numpy as jnp
args = jnp.zeros([2, 3, 5]), jnp.zeros([2, 7, 5], int), 1., 1, False
jaxpr = jax.make_jaxpr(jnp.put_along_axis, [3, 4])(*args)
print(len(jaxpr.eqns))  # 1

@jakevdp
Collaborator

jakevdp commented Nov 20, 2024

> @jakevdp No.

Maybe if you disable JIT it will give you something closer to what you're after?

import jax
from jax import numpy as jnp
args = jnp.zeros([2, 3, 5]), jnp.zeros([2, 7, 5], int), 1., 1, False
with jax.disable_jit():
  jaxpr = jax.make_jaxpr(jnp.put_along_axis, [3, 4])(*args)
print(len(jaxpr.eqns))  # 24

@jakevdp
Collaborator

jakevdp commented Nov 20, 2024

Side note though: jaxpr size is not a very good proxy for complexity in general: for example, jaxprs do not do any dead code elimination. Given this, compiler-based measures like cost analysis will probably be more robust.
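A small sketch of the dead-code point, assuming a function whose `sin` result is never used. `jax.make_jaxpr` keeps that equation, so any count over the jaxpr includes it, while `cost_analysis()` on the compiled executable describes the program after XLA's optimizations. The return type of `cost_analysis()` has changed across JAX versions (a list of per-module dicts in older releases, a dict in newer ones), so the code normalizes it; the exact keys reported can also vary by backend.

```python
import jax
import jax.numpy as jnp


def f(x):
    _ = jnp.sin(x)  # dead code: the result is discarded
    return x * 2.0


x = jnp.ones((1000,))

# The jaxpr still contains the unused `sin` equation next to the `mul`.
jaxpr = jax.make_jaxpr(f)(x)
print([eqn.primitive.name for eqn in jaxpr.eqns])

# Compiler-based measure: XLA's cost analysis of the optimized module.
cost = jax.jit(f).lower(x).compile().cost_analysis()
if isinstance(cost, (list, tuple)):  # older JAX: list of per-module dicts
    cost = cost[0]
print({k: cost[k] for k in ("flops", "transcendentals") if k in cost})
```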

@carlosgmartin
Contributor Author

@jakevdp That's closer to what I'm looking for!

The reason I brought this up is because of your comment:

> I think we should optimize for a simple jaxpr here, not optimize for runtime on a particular device (which may or may not reflect runtime on other device types, or with other array sizes).

which could also be good advice for other functions beyond put_along_axis.

@jakevdp
Collaborator

jakevdp commented Nov 20, 2024

Oh OK – I don't think "simple jaxpr" is necessarily the right metric in general, but in that specific case we were talking about implementing put_along_axis, which I believe could be implemented in most cases via a single scatter call. In that particular case, we should aim to write that single scatter call.
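A rough illustration of the single-scatter idea, as a sketch under stated assumptions: `indices` has the same shape as `arr` on every dimension except `axis`, and there are no duplicate indices along `axis` (with duplicates, which write wins in a JAX scatter is unspecified). `put_along_axis_via_scatter` is a hypothetical name and is not how `jnp.put_along_axis` is actually implemented. It builds a full index tuple from `lax.broadcasted_iota` and lets `.at[...].set` express the write as a scatter.

```python
import jax
import jax.numpy as jnp
import numpy as np


def put_along_axis_via_scatter(arr, indices, values, axis):
    """Hypothetical sketch: write `values` into `arr` at `indices` along `axis`.

    Assumes indices.shape matches arr.shape on every dimension except `axis`.
    """
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    axis = axis % arr.ndim  # accept a negative axis
    # For each dimension other than `axis`, index with that dimension's own
    # coordinate (an iota broadcast to indices.shape); along `axis`, use `indices`.
    idx = tuple(
        indices if d == axis
        else jax.lax.broadcasted_iota(indices.dtype, indices.shape, d)
        for d in range(arr.ndim)
    )
    # A single advanced-index update, which JAX expresses as a `scatter`.
    return arr.at[idx].set(values)


arr = np.zeros((2, 3, 4), dtype=np.float32)
indices = np.broadcast_to(np.array([0, 2]).reshape(1, 2, 1), (2, 2, 4))
values = np.arange(16, dtype=np.float32).reshape(2, 2, 4)

result = put_along_axis_via_scatter(arr, indices, values, axis=1)

# NumPy reference (modifies its first argument in place).
expected = arr.copy()
np.put_along_axis(expected, indices, values, axis=1)
np.testing.assert_allclose(np.asarray(result), expected)
```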


3 participants