An enzyme optimization benchmark #180

Open
vimarsh6739 opened this issue Dec 5, 2024 · 0 comments

Comments

@vimarsh6739
Member

from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental import jet
from folx import forward_laplacian
import enzyme_ad.jax as ejax

from jax._src.lib import xla_client

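# Render compiled HLO as a Graphviz dot graph via jaxlib's (private) xla_client bindings.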
def todotgraph(x):
    return xla_client._xla.hlo_module_to_dot_graph(
        xla_client._xla.hlo_module_from_text(x))

def save_stablehlo(fn, filename, *sample_inputs):
    lowered = fn.lower(*sample_inputs)
    lowered_compiled = lowered.compile()
    with open(f"compiled_{filename}.dot",'w') as f:
        f.write(todotgraph(lowered_compiled.as_text()))
    with open(filename, 'w') as f:
        f.write(str(lowered.compiler_ir("stablehlo")))

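# Toy network: a stack of exp(x @ w) layers followed by a linear readout to a scalar.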
def f(ws, wo, x):
    for w in ws:
        x = jax.lax.exp(x @ w)
    return jnp.reshape(x @ wo, ())

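# Method 1: Taylor-mode AD (jax.experimental.jet) -- take the second-order series
# term along each basis vector and sum them to get the trace of the Hessian.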
@jax.jit
@ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ejax.hlo_opts()))
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_1(ws, wo, x):
    fun = partial(f, ws, wo)
    @jax.vmap
    def hvv(v):
        return jet.jet(fun, (x,), ((v, jnp.zeros_like(x)),))[1][1]
    return jnp.sum(hvv(jnp.eye(x.shape[0], dtype=x.dtype)))

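# Method 2: forward-over-reverse -- jvp of grad(fun) against each basis vector
# gives one Hessian column at a time; the trace of the stacked columns is the Laplacian.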
@jax.jit
@ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ejax.hlo_opts()))
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_2(ws, wo, x):
    fun = partial(f, ws, wo)
    in_tangents = jnp.eye(x.shape[0], dtype=x.dtype)
    pushfwd = partial(jax.jvp, jax.grad(fun), (x,))
    _, hessian = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
    return jnp.trace(hessian)

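# Method 3: materialize the full Hessian with jax.hessian and take its trace.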
@jax.jit
@ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ejax.hlo_opts()))
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_3(ws, wo, x):
    fun = partial(f, ws, wo)
    return jnp.trace(jax.hessian(fun)(x))

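# Method 4: folx's forward-Laplacian transform, which propagates the Laplacian
# alongside the forward pass instead of building the Hessian.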
@jax.jit
@ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ejax.hlo_opts()))
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_4(ws, wo, x):
    fun = partial(f, ws, wo)
    fwd_f = forward_laplacian(fun)
    result = fwd_f(x)
    return result.laplacian

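# Crude wall-clock timer: one warm-up call to trigger compilation, then the
# mean time over n_iter calls.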
def timer(f, n_iter=5):
    from time import time
    f()  # compile
    t = time()
    for _ in range(n_iter):
        f()
    print((time() - t) / n_iter)

def main():
    # Setup dimensions and data
    d = 16
    ws = [jnp.zeros((d, d)) for _ in range(2)]
    wo = jnp.zeros((d, 1))
    x = jnp.zeros((512, d))

    # Save StableHLO IR for each implementation
    print("Saving StableHLO IR to files...")
    save_stablehlo(laplacian_1, "laplacian1_stablehlo.txt", ws, wo, x)
    save_stablehlo(laplacian_2, "laplacian2_stablehlo.txt", ws, wo, x)
    save_stablehlo(laplacian_3, "laplacian3_stablehlo.txt", ws, wo, x)
    save_stablehlo(laplacian_4, "laplacian4_stablehlo.txt", ws, wo, x)

    # Run timing benchmarks
    print('\nBenchmarking different implementations:')
    print('Taylor method:')
    timer(lambda: jax.block_until_ready(laplacian_1(ws, wo, x)))
    print('JVP method:')
    timer(lambda: jax.block_until_ready(laplacian_2(ws, wo, x)))
    print('Hessian method:')
    timer(lambda: jax.block_until_ready(laplacian_3(ws, wo, x)))
    print('Forward Laplacian method:')
    timer(lambda: jax.block_until_ready(laplacian_4(ws, wo, x)))

if __name__ == "__main__":
    main()
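
Each laplacian_k above computes the per-sample Laplacian Δf(x) = tr(∇²f(x)) = Σᵢ ∂²f/∂xᵢ², vmapped over the batch of 512 inputs, with the resulting HLO run through the Enzyme-JaX optimization pipeline (ejax.hlo_opts()).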