Here is the repro.
```python
import os

XLA_FLAGS = [
    "--xla_dump_to=/tmp/hlos",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_graph_level=0",
    "--xla_disable_hlo_passes=rematerialization",
    # Buggy.
    # "--xla_gpu_use_memcpy_local_p2p=true",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_while_loop_double_buffering=true",
]
os.environ["XLA_FLAGS"] = " ".join(XLA_FLAGS)

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
import numpy as np


def matmul(a, b, mesh=None):
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _all_gather_matmul(a, b):
        a_full = jax.lax.all_gather(a, "i", tiled=True)
        return a_full @ b

    return _all_gather_matmul(a, b)


def collective_matmul_unroll(a, b, c, mesh=None):
    """c = a @ b.

    In this version, we use a pre-allocated buffer to benchmark CM without
    invoking memory allocation through broadcast inside the program.
    """
    num_devices = mesh.size
    perm = [(i, (i + 1) % num_devices) for i in range(num_devices)]

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i"), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _collective_matmul(a, b, c):
        idx = jax.lax.axis_index("i")
        for i in range(num_devices):
            c_part = a @ b
            # For each axis shard, get the ith shard prior to it.
            # e.g. for shard 0, if we rotate twice, the 2nd shard ahead of it is N - 2.
            prev_shard_index = (idx - i + num_devices) % num_devices
            c = jax.lax.dynamic_update_slice(c, c_part, (prev_shard_index * a.shape[0], 0))
            if i != num_devices - 1:
                a = jax.lax.ppermute(a, "i", perm)
        return c

    return _collective_matmul(a, b, c)


def collective_matmul_stacked(a, b, c, mesh=None):
    num_devices = mesh.size
    perm = [(i, (i + 1) % num_devices) for i in range(num_devices)]

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i"), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _collective_matmul(a, b, c):
        # A temporary buffer for storing the next shard.
        a_next = jnp.zeros_like(a)
        idx = jax.lax.axis_index("i")

        def _compute_and_update(a, b, c, i):
            """Computes matmul between the sharded a and b and writes to c."""
            prev_shard_index = (idx - i + num_devices) % num_devices
            c = jax.lax.dynamic_update_slice(c, a @ b, (prev_shard_index * a.shape[0], 0))
            return c

        def _body_fn(i, inputs):
            # Double buffering.
            a_curr, a_next, c = inputs
            # Permute a_curr to get a_next.
            a_next = jax.lax.ppermute(a_curr, "i", perm)
            c = _compute_and_update(a_curr, b, c, i)
            return a_next, a_curr, c

        a_last, _, c = jax.lax.fori_loop(0, num_devices - 1, _body_fn, (a, a_next, c))
        return _compute_and_update(a_last, b, c, num_devices - 1)

    return _collective_matmul(a, b, c)


def main():
    devices = jax.devices()[:4]
    mesh = jax.make_mesh((len(devices),), ("i",), devices=devices)
    dtype = jnp.bfloat16

    # Prepare data.
    array_size = 2048
    key = jax.random.PRNGKey(0)
    a = jax.device_put(
        jax.random.uniform(
            key, (array_size * 8, array_size * 6), dtype=dtype, minval=0.0, maxval=0.1
        ),
        NamedSharding(mesh, P("i", None)),
    )
    b = jax.device_put(
        jax.random.uniform(
            key, (array_size * 6, array_size * 16), dtype=dtype, minval=0.0, maxval=0.1
        ),
        NamedSharding(mesh, P(None, "i")),
    )

    # Result buffers.
    c1 = jax.device_put(
        jnp.zeros((array_size * 8, array_size * 16), dtype=dtype),
        NamedSharding(mesh, P(None, "i")),
    )
    c2 = jax.device_put(
        jnp.zeros((array_size * 8, array_size * 16), dtype=dtype),
        NamedSharding(mesh, P(None, "i")),
    )

    # AOT compile.
    # Adds the all-gathered version for comparing non-CM sharded matmul with the ground truth.
    matmul_compiled = jax.jit(matmul, static_argnums=(2,)).lower(a, b, mesh=mesh).compile()
    cm_unroll_compiled = (
        jax.jit(collective_matmul_unroll, static_argnums=(3,), donate_argnums=(2,))
        .lower(a, b, c1, mesh=mesh)
        .compile()
    )
    cm_stacked_compiled = (
        jax.jit(collective_matmul_stacked, static_argnums=(3,), donate_argnums=(2,))
        .lower(a, b, c2, mesh=mesh)
        .compile()
    )

    # Profile each case.
    jax.profiler.start_trace("/tmp/tensorboard")
    ag_matmul_results = [matmul_compiled(a, b) for _ in range(10)]
    jax.block_until_ready(ag_matmul_results)
    for _ in range(10):
        c1 = cm_unroll_compiled(a, b, c1)
    jax.block_until_ready(c1)
    for _ in range(10):
        c2 = cm_stacked_compiled(a, b, c2)
    jax.block_until_ready(c2)
    jax.profiler.stop_trace()

    # Verify numerical correctness.
    ref = a @ b
    try:
        np.testing.assert_allclose(
            ag_matmul_results[0].astype(jnp.float32), ref.astype(jnp.float32)
        )
    except AssertionError as e:
        print(f"Matmul with all-gather gives numerical difference: \n{str(e)}")
    try:
        # Silence the numerical differences caused by different cublas kernels.
        np.testing.assert_allclose(c1.astype(jnp.float32), ref.astype(jnp.float32), rtol=0.01)
    except AssertionError as e:
        print(f"Collective matmul (unrolled) gives numerical difference: \n{str(e)}")
    try:
        np.testing.assert_allclose(c2.astype(jnp.float32), ref.astype(jnp.float32), rtol=0.01)
    except AssertionError as e:
        print(f"Collective matmul (stacked) gives numerical difference: \n{str(e)}")


if __name__ == "__main__":
    main()
```
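The repro above keeps `--xla_gpu_use_memcpy_local_p2p=true` commented out and marked as buggy. Presumably the problem is exercised by turning that flag back on; below is a minimal sketch of the same flag setup with it re-enabled. This variant is my assumption, not part of the original script.

```python
import os

# Same flag list as the repro, but with the flag marked "Buggy" re-enabled.
# Assumption: this is the configuration that triggers the failure.
XLA_FLAGS = [
    "--xla_dump_to=/tmp/hlos",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_graph_level=0",
    "--xla_disable_hlo_passes=rematerialization",
    "--xla_gpu_use_memcpy_local_p2p=true",  # Re-enabled here.
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_while_loop_double_buffering=true",
]
# Must be set before jax is imported so the backend picks the flags up.
os.environ["XLA_FLAGS"] = " ".join(XLA_FLAGS)
```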
cc @Tixxx, please take a look at this.