diff --git a/prototype_source/flight_recorder_tutorial.rst b/prototype_source/flight_recorder_tutorial.rst
index df1abbf9d6..ab83ebe027 100644
--- a/prototype_source/flight_recorder_tutorial.rst
+++ b/prototype_source/flight_recorder_tutorial.rst
@@ -202,6 +202,100 @@ Caveat: tabulate module is needed, so you might need pip install it first.
     python fr_trace.py -j [--selected-ranks i j k ...] [--pg-filters tp dp]
     torchfrtrace -j [--selected-ranks i j k ...] [--pg-filters 0 2]
 
+An End-to-End Example
+---------------------
+To demonstrate the use of Flight Recorder, we will use a small program that deliberately induces a collective mismatch:
+``rank0`` is programmed to issue one additional collective that the other ranks do not.
+The Flight Recorder dump files are saved to the ``/tmp`` directory.
+For demonstration purposes, we named this program ``crash.py``.
+
+.. note::
+    This is a simplified example. In real-world scenarios, the debugging process would involve more complexity.
+
+.. code:: python
+    :caption: A crashing example
+
+    import torch
+    import torch.distributed as dist
+    import os
+    from datetime import timedelta
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    assert world_size <= 8, "world size must be less than or equal to 8"
+    os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
+    os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
+    os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
+    device = torch.device(f"cuda:{local_rank}")
+    print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")
+
+    # Initialize the process group with a small timeout so that jobs fail quickly
+    dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))
+
+    a = torch.full((3, 4), float(local_rank), device=device)
+    # Issue a few collectives to populate the Flight Recorder buffer
+    for i in range(2):
+        print(f"calling allreduce on {local_rank=}")
+        f = dist.all_reduce(a)
+
+    # rank0 issues an additional collective that the other ranks never call
+    if local_rank == 0:
+        print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
+        b = torch.full((4, 5), float(local_rank), device=device)
+        f = dist.all_reduce(b)
+
+    for i in range(2):
+        print(f"calling allreduce on {local_rank=}")
+        f = dist.all_reduce(a)
+
+    torch.cuda.synchronize(device=device)
+    print(f"{local_rank=} exiting")
+
+To run this program, use ``torchrun``:
+
+.. code:: bash
+
+    torchrun --nnodes=1 --nproc_per_node=2 crash.py
+
+You should see two files in the ``/tmp`` directory:
+
+.. code:: bash
+
+    $ ls /tmp/trace*
+    # Expected output
+    /tmp/trace_0 /tmp/trace_1
+
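+Optionally, you can inspect a raw dump directly before running the analyzer. The snippet below is a
+minimal sketch, assuming each dump file is a pickled dictionary with an ``entries`` list (one item per
+recorded collective); the exact field names (``entries``, ``profiling_name``, ``state``) may vary
+between PyTorch versions.
+
+.. code:: python
+    :caption: Peeking at a raw dump (a minimal sketch; field names are assumptions)
+
+    import pickle
+
+    # Load the dump written by rank 0 (the path prefix was set via TORCH_NCCL_DEBUG_INFO_TEMP_FILE)
+    with open("/tmp/trace_0", "rb") as f:
+        dump = pickle.load(f)
+
+    # Top-level keys typically include the version, the process group configuration, and the entries
+    print(sorted(dump.keys()))
+
+    # Print a one-line summary per recorded collective, guarding against missing fields
+    for entry in dump.get("entries", []):
+        name = entry.get("profiling_name", "<unknown>")
+        state = entry.get("state", "<unknown>")
+        print(f"{name}: {state}")
+
+Comparing the summaries for ``/tmp/trace_0`` and ``/tmp/trace_1`` should already hint at the problem:
+rank 0 records one more ``all_reduce`` entry than rank 1.
+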
+Finally, to analyze these two files, we use the ``torchfrtrace`` command:
+
+.. code:: bash
+
+    torchfrtrace --prefix "trace_" /tmp/
+
+The output of the trace command is meant to be human-readable. It identifies the set of collectives
+responsible for the failure. The output for the command above is shown below; we can clearly see that
+rank 1 did not join the ``all_reduce`` collective.
+
+.. code:: bash
+
+    $ torchfrtrace --prefix "trace_" /tmp/
+    Not all ranks joining collective 5 at entry 4
+    group info: 0:default_pg
+    collective: nccl:all_reduce
+    missing ranks: {1}
+    input sizes: [[3, 4]]
+    output sizes: [[3, 4]]
+    expected ranks: 2
+    collective state: scheduled
+    collective stack trace:
+         all_reduce at /home/cpio/local/pytorch/torch/distributed/distributed_c10d.py:2696
+         wrapper at /home/cpio/local/pytorch/torch/distributed/c10d_logger.py:83
+         at /home/cpio/test/crash.py:44
+
+
 Conclusion
 ----------
 In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.