[GPU bug] Memcpy local p2p leads to numerical issues (mem access problems). #19555

Open
yliu120 opened this issue Nov 20, 2024 · 1 comment
yliu120 commented Nov 20, 2024

Here is the repro.

import os

XLA_FLAGS = [
    "--xla_dump_to=/tmp/hlos",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
    "--xla_gpu_graph_level=0",
    "--xla_disable_hlo_passes=rematerialization",
    #  Buggy.
    # "--xla_gpu_use_memcpy_local_p2p=true",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_while_loop_double_buffering=true",
]
os.environ["XLA_FLAGS"] = " ".join(XLA_FLAGS)
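
# The flag commented out above (--xla_gpu_use_memcpy_local_p2p=true) is the one reported as
# buggy: per the issue title, enabling it leads to numerical issues / memory-access problems,
# whereas with it left disabled the correctness checks in main() are expected to pass.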

from functools import partial

import jax
import jax.numpy as jnp

from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding

import numpy as np


def matmul(a, b, mesh=None):
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _all_gather_matmul(a, b):
        a_full = jax.lax.all_gather(a, "i", tiled=True)
        return a_full @ b

    return _all_gather_matmul(a, b)
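
# This all-gather-then-matmul version serves as the non-collective-matmul baseline; main()
# compares its output (and the two CM variants) against a plain `a @ b` reference.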


def collective_matmul_unroll(a, b, c, mesh=None):
    """c = a @ b.

    In this version, we use pre-allocated buffer to benchmark CM without invoking
    memory allocation through broadcast inside the program.
    """
    num_devices = mesh.size
    perm = [(i, (i + 1) % num_devices) for i in range(num_devices)]

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i"), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _collective_matmul(a, b, c):
        idx = jax.lax.axis_index("i")
        for i in range(num_devices):
            c_part = a @ b
            # After i rotations around the ring, this device holds the A shard that
            # originated i positions before it, e.g. device 0 after two rotations holds
            # shard N - 2.
            prev_shard_index = (idx - i + num_devices) % num_devices
            c = jax.lax.dynamic_update_slice(c, c_part, (prev_shard_index * a.shape[0], 0))
            if i != num_devices - 1:
                a = jax.lax.ppermute(a, "i", perm)

        return c

    return _collective_matmul(a, b, c)
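
# Worked example (for illustration, assuming num_devices == 4): the ring permutation is
#   perm = [(0, 1), (1, 2), (2, 3), (3, 0)],
# so after i rotations device `idx` holds the A shard that originated on device
# (idx - i) % 4. For idx == 0 that is shard 0, 3, 2, 1 for i = 0, 1, 2, 3, which is why
# each partial product lands at row offset prev_shard_index * a.shape[0] in c.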


def collective_matmul_stacked(a, b, c, mesh=None):
    num_devices = mesh.size
    perm = [(i, (i + 1) % num_devices) for i in range(num_devices)]

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i", None), P(None, "i"), P(None, "i")),
        out_specs=P(None, "i"),
        check_rep=False,
    )
    def _collective_matmul(a, b, c):
        # A temporary buffer for storing the next shard.
        a_next = jnp.zeros_like(a)
        idx = jax.lax.axis_index("i")

        def _compute_and_update(a, b, c, i):
            """Computes matmul between the sharded a and b and writes to c."""
            prev_shard_index = (idx - i + num_devices) % num_devices
            c = jax.lax.dynamic_update_slice(c, a @ b, (prev_shard_index * a.shape[0], 0))
            return c

        def _body_fn(i, inputs):
            # Double buffering.
            a_curr, a_next, c = inputs
            # Permute a_curr to get a_next
            a_next = jax.lax.ppermute(a_curr, "i", perm)
            c = _compute_and_update(a_curr, b, c, i)
            return a_next, a_curr, c

        a_last, _, c = jax.lax.fori_loop(0, num_devices - 1, _body_fn, (a, a_next, c))
        return _compute_and_update(a_last, b, c, num_devices - 1)

    return _collective_matmul(a, b, c)
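
# For illustration: each fori_loop step rotates the carry as
#   (a_curr, a_next) -> (ppermute(a_curr), a_curr),
# so the transfer of the next A shard can overlap with the matmul on the current one
# (double buffering); the final shard is consumed outside the loop, avoiding a trailing
# ppermute.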


def main():
    devices = jax.devices()[:4]
    mesh = jax.make_mesh((len(devices),), ("i",), devices=devices)
    dtype = jnp.bfloat16

    # Prepare data.
    array_size = 2048
    key = jax.random.PRNGKey(0)
    a = jax.device_put(
        jax.random.uniform(
            key, (array_size * 8, array_size * 6), dtype=dtype, minval=0.0, maxval=0.1
        ),
        NamedSharding(mesh, P("i", None)),
    )
    b = jax.device_put(
        jax.random.uniform(
            key, (array_size * 6, array_size * 16), dtype=dtype, minval=0.0, maxval=0.1
        ),
        NamedSharding(mesh, P(None, "i")),
    )
    # Result buffers (pre-allocated, one per collective-matmul variant).
    c1 = jax.device_put(
        jnp.zeros((array_size * 8, array_size * 16), dtype=dtype),
        NamedSharding(mesh, P(None, "i")),
    )
    c2 = jax.device_put(
        jnp.zeros((array_size * 8, array_size * 16), dtype=dtype),
        NamedSharding(mesh, P(None, "i")),
    )

    # AOT-compile all three variants. The all-gather version is included to compare the
    # non-CM sharded matmul against the ground truth.
    matmul_compiled = jax.jit(matmul, static_argnums=(2,)).lower(a, b, mesh=mesh).compile()
    cm_unroll_compiled = (
        jax.jit(collective_matmul_unroll, static_argnums=(3,), donate_argnums=(2,))
        .lower(a, b, c1, mesh=mesh)
        .compile()
    )
    cm_stacked_compiled = (
        jax.jit(collective_matmul_stacked, static_argnums=(3,), donate_argnums=(2,))
        .lower(a, b, c2, mesh=mesh)
        .compile()
    )

    # Profile each case.
    jax.profiler.start_trace("/tmp/tensorboard")
    ag_matmul_results = [matmul_compiled(a, b) for _ in range(10)]
    jax.block_until_ready(ag_matmul_results)

    for _ in range(10):
        c1 = cm_unroll_compiled(a, b, c1)
    jax.block_until_ready(c1)

    for _ in range(10):
        c2 = cm_stacked_compiled(a, b, c2)
    jax.block_until_ready(c2)
    jax.profiler.stop_trace()

    # Verify numerical correctness.
    ref = a @ b
    try:
        np.testing.assert_allclose(
            ag_matmul_results[0].astype(jnp.float32), ref.astype(jnp.float32)
        )
    except AssertionError as e:
        print(f"Matmul with all-gather gives numerical difference: \n{str(e)}")

    try:
        # A loose rtol hides small numerical differences caused by different cuBLAS kernels.
        np.testing.assert_allclose(c1.astype(jnp.float32), ref.astype(jnp.float32), rtol=0.01)
    except AssertionError as e:
        print(f"Collective matmul (unrolled) gives numerical difference: \n{str(e)}")

    try:
        np.testing.assert_allclose(c2.astype(jnp.float32), ref.astype(jnp.float32), rtol=0.01)
    except AssertionError as e:
        print(f"Collective matmul (stacked) gives gives numerical difference: \n{str(e)}")


if __name__ == "__main__":
    main()

@abhinavgoel95 commented:

cc @Tixxx please have a look at this.

@Tixxx self-assigned this Nov 22, 2024