More FSDP optimizations (#10)
* profile training loop

* add `--trace-output` arg

* skip norm for empty grads

* expand user

* debug logging

* post-backward final hook

* better logging

* mark backward exec order finalized for children

* fix lint
epwalsh authored Apr 17, 2024
1 parent f235164 commit 2cfd3a7
Showing 2 changed files with 115 additions and 46 deletions.
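
For context on the first change ("profile training loop"): the benchmark below wraps its training loop in a torch.profiler pass, driven by a schedule that waits 1 step, warms up for 5, records 3, and then exports a Chrome trace from `on_trace_ready`. Below is a minimal, self-contained sketch of that pattern; the dummy workload, step count, and output path are illustrative and not part of the commit.

```python
# Minimal sketch of the profiling pattern used in train.py below (illustrative only).
# The schedule skips 1 step, warms up for 5, records 3, then fires on_trace_ready.
from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile, schedule

trace_output = "/tmp/traces/example.chrome_trace.json.gz"  # stand-in for --trace-output

def on_trace_ready(p: torch.profiler.profile) -> None:
    trace_path = Path(trace_output).expanduser()
    trace_path.parent.mkdir(exist_ok=True, parents=True)
    p.export_chrome_trace(str(trace_path))

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(
    activities=activities,
    with_stack=True,
    schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
    on_trace_ready=on_trace_ready,
) as p:
    for step in range(10):
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for one training batch
        p.step()  # advance the profiler schedule once per batch
```

The resulting `.json.gz` file is a Chrome trace and can be opened in a trace viewer such as Perfetto.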
80 changes: 58 additions & 22 deletions src/benchmarks/fsdp/train.py
@@ -4,8 +4,8 @@
"""

import argparse
import contextlib
import logging
import os
import time
from pathlib import Path
from typing import Literal, Optional
@@ -32,6 +32,8 @@ def main(
save_path: Optional[str] = None,
load_path: Optional[str] = None,
mixed_precision: bool = True,
profile: bool = False,
trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
**kwargs,
):
model, optim, dataloader = build_components(
@@ -56,33 +58,56 @@
print_rank0(f"Saving checkpoint to {checkpoint_dir}...")
save_model_and_optim_state(checkpoint_dir, model, optim)

profiler = contextlib.nullcontext()
if profile:
from torch.profiler import ProfilerActivity, schedule

def on_trace_ready(p):
trace_path = Path(trace_output).expanduser()
trace_path.parent.mkdir(exist_ok=True, parents=True)
p.export_chrome_trace(str(trace_path))
print_rank0(f"Tracing complete, saved to '{trace_path}'")

profiler = torch.profiler.profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=False,
profile_memory=False,
with_stack=True,
schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
on_trace_ready=on_trace_ready,
)

print_rank0("Starting training...")
for i, batch in enumerate(iter(dataloader)):
log.debug("Batch: %s", batch)
batch_start = time.monotonic()
with profiler as p:
for i, batch in enumerate(iter(dataloader)):
log.debug("Batch: %s", batch)
batch_start = time.monotonic()

# Zero-gradients.
optim.zero_grad()
# Zero-gradients.
optim.zero_grad()

# Run forward pass.
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
loss = compute_loss(model, batch)
# Run forward pass.
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
loss = compute_loss(model, batch)

# Trigger backward pass.
loss.backward()
# Trigger backward pass.
loss.backward()

# Clip gradient norms.
model.clip_grad_norm_(1.0)
# Clip gradient norms.
model.clip_grad_norm_(1.0)

# Take optimizer step.
optim.step()
# Take optimizer step.
optim.step()

batch_end = time.monotonic()
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
)
batch_end = time.monotonic()
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
)

if p is not None:
p.step()

if save_path is not None:
checkpoint_dir = Path(save_path) / "final"
@@ -126,6 +151,15 @@ def main(
"--debug",
action="store_true",
)
parser.add_argument(
"--profile",
action="store_true",
)
parser.add_argument(
"--trace-output",
type=str,
default="/tmp/traces/olmo_core.chrome_trace.json.gz",
)
parser.add_argument(
"--save-path",
type=str,
@@ -168,7 +202,7 @@ def main(
raise NotImplementedError(args.model_size)

if args.debug:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
config.debug = True

dist.init_process_group(backend="nccl")
@@ -185,6 +219,8 @@
dry_run=args.dry_run,
save_path=args.save_path,
load_path=args.load_path,
profile=args.profile,
trace_output=args.trace_output,
mixed_precision=mixed_precision,
max_prefetch_count=args.max_prefetch_count,
learning_rate=args.lr,
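One small change in the fsdp.py diff below ("skip norm for empty grads") makes clip_grad_norm_ ignore parameters whose local gradient holds zero elements, which can happen, for example, when a rank owns an empty shard of a flat parameter. A minimal sketch of that guard follows; local_grad_norm and its surrounding code are illustrative, not functions from the repo.

```python
# Sketch of the empty-grad guard (illustrative names, not from the repo):
# zero-element gradients contribute nothing to the norm, so they are skipped.
from typing import Iterable

import torch
import torch.nn as nn

def local_grad_norm(params: Iterable[nn.Parameter], norm_type: float = 2.0) -> torch.Tensor:
    grads = [p.grad for p in params if p.grad is not None and p.grad.numel() > 0]
    if not grads:
        return torch.tensor(0.0)
    # The norm of per-gradient norms equals the norm over all gradients concatenated.
    per_grad = torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads])
    return torch.linalg.vector_norm(per_grad, norm_type)
```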
81 changes: 57 additions & 24 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -26,6 +26,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.autograd import Variable

from olmo_core.distributed.tensors import ShardedFlatParameter
from olmo_core.stream import Stream
@@ -322,7 +323,7 @@ def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tens
nonsharded_params: Set[nn.Parameter] = set()
grads: List[torch.Tensor] = []
for param in self.parameters():
if param.grad is None:
if param.grad is None or param.grad.numel() == 0:
continue

if isinstance(param, ShardedFlatParameter):
@@ -394,7 +395,11 @@ def _lazy_init(self):
self.state.forward_execution_order.append(self)
return

log.debug("Completing lazy initialization from root FSDP for %s...", self.module.__class__.__name__)
log.debug(
"Completing lazy initialization from root FSDP for %s (%s)...",
self.module.__class__.__name__,
id(self.module),
)

# Initialize streams.
self.state.compute_stream = Stream.default(self.device)
@@ -494,7 +499,7 @@ def _shard(self):
This should only be called once at initialization.
"""
log.debug("Sharding %s...", self.module.__class__.__name__)
log.debug("Sharding %s (%s)...", self.module.__class__.__name__, id(self.module))

params_with_grads: List[nn.Parameter] = []
params_with_grads_fqns: List[str] = []
@@ -568,7 +573,7 @@ def _unshard(

kwargs = dict(cast=cast, set_grads=set_grads, recurse=recurse, rank0_only=rank0_only)

log.debug("Unsharding %s...", self.module.__class__.__name__)
log.debug("Unsharding %s (%s)...", self.module.__class__.__name__, id(self.module))
self.state.params_prefetched = True

# NOTE: `unshard_stream` should wait on current stream (usually `compute_stream` / `default_stream`)
@@ -600,7 +605,11 @@
def _prefetch(self, prefetch_from: deque[FSDP], **kwargs):
for module in self._deque_from(prefetch_from):
log.debug(
"Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
"Prefetching %s (%s) from %s (%s)...",
module.module.__class__.__name__,
id(module.module),
self.module.__class__.__name__,
id(self.module),
)
module._unshard(**kwargs)

@@ -611,7 +620,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False):
"""
kwargs = dict(writeback=writeback, recurse=recurse)

log.debug("Resharding %s...", self.module.__class__.__name__)
log.debug("Resharding %s (%s)...", self.module.__class__.__name__, id(self.module))
self.state.params_prefetched = False

for handle in self.state.flat_param_handles:
@@ -637,7 +646,7 @@ def _reduce_scatter_grads(self):

grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype
with self.state.reduce_stream(wait_stream=self.state.current_stream):
log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__)
log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))
for handle in self.state.flat_param_handles:
handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)

@@ -659,13 +668,16 @@ def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None
@torch.no_grad()
def _pre_backward_hook(self, *unused: Any):
del unused
log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__)
log.debug("Running pre-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module))

# Remove all pre backward hooks for this FSDP instance since they all do the same thing.
for handle in self.state.pre_backward_hook_handles:
handle.remove()
self.state.pre_backward_hook_handles.clear()

if self.is_root:
self._register_post_backward_final_hook()

# Unshard parameters in place.
self._unshard(set_grads=True)

@@ -684,10 +696,12 @@ def _register_pre_backward_hook(self, x: torch.Tensor):
self.state.pre_backward_hook_handles.append(handle)

def _register_pre_backward_hooks(self, output: Any):
log.debug("Registering pre-backward hooks for %s...", self.module.__class__.__name__)
log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module))
# Clear existing hooks if there are any.
if self.state.pre_backward_hook_handles:
log.debug("Removing old pre-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
for handle in self.state.pre_backward_hook_handles:
handle.remove()
self.state.pre_backward_hook_handles.clear()
@@ -699,29 +713,19 @@ def _register_pre_backward_hooks(self, output: Any):
@torch.no_grad()
def _post_backward_hook(self, param_name: str, *unused: Any):
del unused
log.debug("Running post-backward hook for %s.%s...", self.module.__class__.__name__, param_name)
self.state.post_backward_hook_handles.pop(param_name).remove()

# If there are still more handles then there are still more post-backward hooks to be ran
# in the current FSDP node. Only the last handle should do the work.
if self.state.post_backward_hook_handles:
return

log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))

# NOTE: reshard *before* reducing grads to correctly handle precision settings.
self._reshard()
self._reduce_scatter_grads()

# The root FSDP instance needs to do some final cleanup.
if not self.is_root:
return

# Mark backward execution order as finalized.
self.state.backward_execution_order_finalized = True

# Wait for unsharding and reducing streams to complete so the model is not left in a bad
# state before grad clipping, optimizer step, or whatever else.
self.state.current_stream.wait_stream(self.state.reduce_stream)

def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
# Force creation of a `grad_fn` in order to register a hook that will run *after* this param's
# backward pass.
@@ -733,13 +737,42 @@ def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParame
self.state.post_backward_hook_handles[param_name] = handle

def _register_post_backward_hooks(self):
log.debug("Registering post-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Registering post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
# Clear existing hooks if there are any.
if self.state.post_backward_hook_handles:
log.debug("Removing old post-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Removing old post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
for handle in self.state.post_backward_hook_handles.values():
handle.remove()
self.state.post_backward_hook_handles.clear()
for param_name, param in self._managed_named_parameters():
if param.requires_grad:
self._register_post_backward_hook(param_name, param)

@torch.no_grad()
def _post_backward_final_hook(self):
if not self.is_root:
return

log.debug("Running post-backward final hook for %s (%s)", self.module.__class__.__name__, id(self.module))

# Mark backward execution order as finalized.
self.state.backward_execution_order_finalized = True
for child in self._fsdp_children(recurse=True):
child.state.backward_execution_order_finalized = True

# Wait for unsharding and reducing streams to complete so the model is not left in a bad
# state before grad clipping, optimizer step, or whatever else.
self.state.current_stream.wait_stream(self.state.reduce_stream)

def _register_post_backward_final_hook(self):
if not self.is_root:
return

log.debug(
"Registering post-backward final hook for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
Variable._execution_engine.queue_callback(self._post_backward_final_hook)
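
The main new piece above, _register_post_backward_final_hook, leans on the autograd engine's callback queue: a callable passed to Variable._execution_engine.queue_callback from inside a backward pass runs once, after the entire pass has finished, which is why the root FSDP instance queues it from its pre-backward hook. A minimal standalone sketch of that mechanism, with a toy tensor and a print standing in for the real cleanup:

```python
# Sketch of the post-backward final-hook mechanism (toy example, not from the repo):
# a callback queued from inside the backward pass runs after the whole pass finishes.
import torch
from torch.autograd import Variable

def finalize() -> None:
    # In the FSDP code above this is where backward execution order is marked
    # finalized and the compute stream waits on the reduce-scatter stream.
    print("backward finished")

x = torch.randn(4, requires_grad=True)
loss = (x * 2.0).sum()

# queue_callback must be called while backward is running, so register it from a
# hook that fires during the backward pass (FSDP uses its pre-backward hook).
x.register_hook(lambda grad: Variable._execution_engine.queue_callback(finalize))
loss.backward()  # "backward finished" prints only after all gradients are computed
```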
