[doc] add small example to flight recorder tutorial #3163

Merged · 4 commits · Nov 26, 2024
Changes from 3 commits
94 changes: 94 additions & 0 deletions prototype_source/flight_recorder_tutorial.rst
@@ -202,6 +202,100 @@ Caveat: the tabulate module is needed, so you might need to pip install it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]
An End-to-End Example
------------------------------------
To demonstrate the use of Flight Recorder, we will use a small program in which we deliberately induce mismatched collectives.
In this example, ``rank0`` is programmed to issue one additional collective that the other ranks do not.
The Flight Recorder dump files are saved to the ``/tmp`` directory.
For demonstration purposes, we named this program ``crash.py``.

.. note::
   This is a simplified example. Real-world scenarios would involve
   additional complexity.

.. code:: python
   :caption: A crashing example

   import torch
   import torch.distributed as dist
   import os
   from datetime import timedelta

   local_rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   assert world_size <= 8, "world size must be less than or equal to 8"

   os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
   os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
   os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

   device = torch.device(f"cuda:{local_rank}")
   print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

   # Initialize the process group with a small timeout so that jobs fail quickly
   dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

   a = torch.full((3, 4), float(local_rank), device=device)
   # Write some collectives to populate Flight Recorder data
   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   # rank0 is doing an additional collective
   if local_rank == 0:
       print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
       b = torch.full((4, 5), float(local_rank), device=device)
       f = dist.all_reduce(b)

   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   torch.cuda.synchronize(device=device)
   print(f"{local_rank=} exiting")

To run this program, use ``torchrun``:


.. code:: bash

   torchrun --nnodes=1 --nproc_per_node=2 crash.py

You should see two files in the ``/tmp`` directory:

.. code:: bash

   $ ls /tmp/trace*
   # Expected output
   /tmp/trace_0 /tmp/trace_1

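Before running the analyzer, you can also peek at a raw dump directly. The sketch below is a minimal example,
assuming the default pickle serialization of Flight Recorder dumps; the exact top-level keys (such as
``entries``) may vary across PyTorch versions:

.. code:: python

   import pickle

   # Each per-rank dump written under TORCH_NCCL_DEBUG_INFO_TEMP_FILE is,
   # by default, a pickled Python dictionary.
   with open("/tmp/trace_0", "rb") as f:
       dump = pickle.load(f)

   # Inspect the recorded collective entries (key names may differ
   # across PyTorch versions).
   print(dump.keys())
   for entry in dump.get("entries", [])[:3]:
       print(entry.get("profiling_name"), entry.get("state"))
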
Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

   torchfrtrace --prefix "trace_" /tmp/

The output of the trace command is human-readable and includes information about the
set of collectives that caused the failure.
The output for the command above is shown below.
We can clearly see that rank 1 did not join the ``all_reduce`` collective.

.. code-block:: bash

   $ torchfrtrace --prefix "trace_" /tmp/
   Not all ranks joining collective 5 at entry 4
   group info: 0:default_pg
   collective: nccl:all_reduce
   missing ranks: {1}
   input sizes: [[3, 4]]
   output sizes: [[3, 4]]
   expected ranks: 2
   collective state: scheduled
   collective stack trace:
        all_reduce at /home/cpio/local/pytorch/torch/distributed/distributed_c10d.py:2696
        wrapper at /home/cpio/local/pytorch/torch/distributed/c10d_logger.py:83
        <module> at /home/cpio/test/crash.py:44

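If you prefer to cross-check the dumps programmatically instead of reading the textual report, a minimal
sketch along the same lines (again assuming the pickle format and an ``entries`` key, as above) is to compare
how many collectives each rank recorded:

.. code:: python

   import pickle

   # Count the collectives recorded by each of the two ranks; a mismatch
   # points at the rank that issued (or missed) an extra collective.
   counts = {}
   for rank in range(2):
       with open(f"/tmp/trace_{rank}", "rb") as f:
           counts[rank] = len(pickle.load(f).get("entries", []))
   print(counts)  # rank 0 should show one more entry than rank 1
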
Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.