Add a way to get the tree size of a jaxpr #24995
Can you add a few details about what you are trying to do, please?
Does
@superbobry I was looking for some rough measure of jaxpr "size" or "complexity"/"simplicity" in order to compare different potential implementations of a JAX function, alongside cost analysis. (See this comment, for example.)

@jakevdp No.

```python
import jax
from jax import numpy as jnp

args = jnp.zeros([2, 3, 5]), jnp.zeros([2, 7, 5], int), 1., 1, False
jaxpr = jax.make_jaxpr(jnp.put_along_axis, [3, 4])(*args)
print(len(jaxpr.eqns))  # 1
```
Maybe if you disable JIT it will give you something closer to what you're after?

```python
import jax
from jax import numpy as jnp

args = jnp.zeros([2, 3, 5]), jnp.zeros([2, 7, 5], int), 1., 1, False
# With JIT disabled, jitted functions are traced inline instead of appearing
# as a single call equation.
with jax.disable_jit():
    jaxpr = jax.make_jaxpr(jnp.put_along_axis, [3, 4])(*args)
print(len(jaxpr.eqns))  # 24
```
Side note, though: jaxpr size is not a very good proxy for complexity in general. For example, jaxprs produced by tracing do not undergo any dead-code elimination, so unused computations still show up as equations. Given this, compiler-based measures like cost analysis will probably be more robust.
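Both points can be checked directly: the traced jaxpr keeps an unused operation, while the compiler's own estimate is available through the ahead-of-time lowering API. A minimal sketch; `with_dead_code` and `flops_estimate` are hypothetical helpers, and the return type and keys of `cost_analysis()` vary by JAX version and backend, so the code accepts either a dict or a list of dicts and tolerates a missing result:

```python
import jax
from jax import numpy as jnp

def with_dead_code(x):
    unused = jax.lax.sin(x)  # computed but never used
    return x * 2.0

# make_jaxpr records every traced operation, including the unused sin.
closed = jax.make_jaxpr(with_dead_code)(1.0)
primitives = [eqn.primitive.name for eqn in closed.eqns]
print(primitives)

def flops_estimate(f, *args):
    """Return XLA's estimated flop count for f(*args), or None if unavailable."""
    compiled = jax.jit(f).lower(*args).compile()
    analysis = compiled.cost_analysis()
    # Some JAX versions return a list with one dict per computation.
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else None
    if not analysis:
        return None
    return analysis.get("flops")

print(flops_estimate(with_dead_code, jnp.ones(1000)))
```

The estimate is computed on the compiled program, i.e. after XLA's optimizations, which is why it is less sensitive to how the Python code happens to be written.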
@jakevdp That's closer to what I'm looking for! The reason I brought this up is because of your comment:
which could also be good advice for other functions beyond
Oh OK – I don't think "simple jaxpr" is necessarily the right metric in general, but in that specific case we were talking about implementing
Feature request: Add a way to get the size of a `Jaxpr` and `ClosedJaxpr`, in terms of the size of the tree itself (not the length of its string representation). This could be done by adding a `__len__` dunder to these classes, or another method.
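Until something like this exists in JAX itself, a rough tree size can be computed in user code by counting equations and recursing into nested jaxprs stored in equation parameters (for example the body of a jitted call). A minimal sketch; `jaxpr_tree_size` is a hypothetical helper, not a JAX API, and it assumes nested jaxprs appear in `eqn.params` either directly or inside tuples/lists, which covers common cases but may miss some primitives:

```python
import jax
from jax import numpy as jnp

def _nested_jaxprs(value):
    """Yield jaxpr-like objects found in an equation parameter value."""
    if hasattr(value, "eqns"):  # both Jaxpr and ClosedJaxpr expose .eqns
        yield value
    elif isinstance(value, (tuple, list)):
        for item in value:
            yield from _nested_jaxprs(item)

def jaxpr_tree_size(jaxpr):
    """Count equations in a (Closed)Jaxpr, including all nested sub-jaxprs."""
    total = 0
    for eqn in jaxpr.eqns:
        total += 1  # the equation itself
        for value in eqn.params.values():
            for sub in _nested_jaxprs(value):
                total += jaxpr_tree_size(sub)
    return total

@jax.jit
def inner(x):
    return jnp.sin(x) * 2.0

def outer(x):
    return inner(x) + 1.0

closed = jax.make_jaxpr(outer)(1.0)
print(len(closed.eqns), jaxpr_tree_size(closed))
```

Unlike tracing under `jax.disable_jit()`, this keeps the original trace and only changes how it is counted; it shares the caveat that dead code is counted too.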