We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
from functools import partial
from time import perf_counter

import jax
import jax.numpy as jnp
from jax.experimental import jet
from jax._src.lib import xla_client

from folx import forward_laplacian
import enzyme_ad.jax as ejax


def todotgraph(x):
    """Render HLO module text ``x`` as a Graphviz DOT string.

    NOTE: relies on private ``xla_client._xla`` helpers, which may change
    between jaxlib releases.
    """
    return xla_client._xla.hlo_module_to_dot_graph(
        xla_client._xla.hlo_module_from_text(x)
    )


def save_stablehlo(fn, filename, *sample_inputs):
    """Dump IR for the jitted ``fn`` lowered on ``sample_inputs``.

    Writes two files:
      * ``compiled_{filename}.dot`` -- DOT graph of the compiled module's text.
      * ``filename`` -- the StableHLO text of the lowered (pre-compile) module.

    ``fn`` must expose ``.lower`` (i.e. be wrapped by ``jax.jit``).
    """
    lowered = fn.lower(*sample_inputs)
    compiled = lowered.compile()
    with open(f"compiled_{filename}.dot", "w", encoding="utf-8") as dot_file:
        dot_file.write(todotgraph(compiled.as_text()))
    with open(filename, "w", encoding="utf-8") as ir_file:
        ir_file.write(str(lowered.compiler_ir("stablehlo")))


def f(ws, wo, x):
    """Toy scalar network: ``x <- exp(x @ w)`` for each ``w`` in ``ws``, then ``x @ wo``.

    The final product must contain exactly one element (in this script
    ``wo`` is (d, 1) and ``x`` is a single (d,) sample under ``vmap``); it is
    reshaped to a 0-d scalar so the function can be differentiated.
    """
    for w in ws:
        x = jax.lax.exp(x @ w)
    return jnp.reshape(x @ wo, ())


def _enzyme_hlo_opts():
    """Return a fresh ``enzyme_jax_ir`` decorator configured with ``hlo_opts()``."""
    # A new pipeline object is built per call so every decorated function gets
    # its own instance, exactly as with the inline decorator expression.
    return ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ejax.hlo_opts()))


@jax.jit
@_enzyme_hlo_opts()
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_1(ws, wo, x):
    """Laplacian of ``f`` via second-order Taylor mode (``jet``).

    For each identity basis vector ``v``, ``jet`` is seeded with input series
    ``(v, 0)`` and the second-order output term is taken; under jet's
    derivative-coefficient convention that term is ``v^T H v``, so summing
    over the basis yields ``trace(H)``. Batched over rows of ``x``.
    """
    fun = partial(f, ws, wo)

    @jax.vmap
    def hvv(v):
        return jet.jet(fun, (x,), ((v, jnp.zeros_like(x)),))[1][1]

    return jnp.sum(hvv(jnp.eye(x.shape[0], dtype=x.dtype)))


@jax.jit
@_enzyme_hlo_opts()
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_2(ws, wo, x):
    """Laplacian of ``f`` via forward-over-reverse (``jvp`` of ``grad``).

    ``vmap`` pushes every identity tangent through ``jvp(grad(fun))`` to build
    the full Hessian; ``out_axes=(None, 0)`` keeps the (unbatched) gradient
    output once and stacks the Hessian rows. Returns its trace.
    """
    fun = partial(f, ws, wo)
    in_tangents = jnp.eye(x.shape[0], dtype=x.dtype)
    pushfwd = partial(jax.jvp, jax.grad(fun), (x,))
    _, hessian = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
    return jnp.trace(hessian)


@jax.jit
@_enzyme_hlo_opts()
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_3(ws, wo, x):
    """Laplacian of ``f`` as the trace of the dense ``jax.hessian``."""
    fun = partial(f, ws, wo)
    return jnp.trace(jax.hessian(fun)(x))


@jax.jit
@_enzyme_hlo_opts()
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_4(ws, wo, x):
    """Laplacian of ``f`` via folx's ``forward_laplacian`` (``.laplacian`` field)."""
    fun = partial(f, ws, wo)
    fwd_f = forward_laplacian(fun)
    result = fwd_f(x)
    return result.laplacian


def timer(f, n_iter=5):
    """Print and return the mean time in seconds of ``f()`` over ``n_iter`` calls.

    ``f`` is called once beforehand so JIT compilation is not measured.
    ``n_iter`` must be positive. Uses ``perf_counter`` (monotonic,
    high-resolution) instead of the wall clock.
    """
    f()  # warm-up: triggers compilation
    start = perf_counter()
    for _ in range(n_iter):
        f()
    mean = (perf_counter() - start) / n_iter
    print(mean)
    return mean


def main():
    """Dump IR for each Laplacian implementation, then benchmark them."""
    # Setup dimensions and data.
    # NOTE: every matmul here is against an all-zero weight, so f(x) == 0 for
    # all x and every Laplacian is 0 -- fine for timing, not for checking values.
    d = 16
    ws = [jnp.zeros((d, d)) for _ in range(2)]
    wo = jnp.zeros((d, 1))
    x = jnp.zeros((512, d))

    # (benchmark label, implementation, StableHLO output file)
    implementations = (
        ("Taylor method:", laplacian_1, "laplacian1_stablehlo.txt"),
        ("JVP method:", laplacian_2, "laplacian2_stablehlo.txt"),
        ("Hessian method:", laplacian_3, "laplacian3_stablehlo.txt"),
        ("Forward Laplacian method:", laplacian_4, "laplacian4_stablehlo.txt"),
    )

    # Save StableHLO IR for each implementation
    print("Saving StableHLO IR to files...")
    for _, fn, filename in implementations:
        save_stablehlo(fn, filename, ws, wo, x)

    # Run timing benchmarks
    print('\nBenchmarking different implementations:')
    for label, fn, _ in implementations:
        print(label)
        # Bind fn as a default so the closure does not depend on loop state.
        timer(lambda fn=fn: jax.block_until_ready(fn(ws, wo, x)))


if __name__ == "__main__":
    main()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
The text was updated successfully, but these errors were encountered: