diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py
index 9e295819..82d79851 100644
--- a/src/benchmarks/fsdp/train.py
+++ b/src/benchmarks/fsdp/train.py
@@ -4,8 +4,8 @@
 """
 
 import argparse
+import contextlib
 import logging
-import os
 import time
 from pathlib import Path
 from typing import Literal, Optional
@@ -32,6 +32,8 @@ def main(
     save_path: Optional[str] = None,
     load_path: Optional[str] = None,
     mixed_precision: bool = True,
+    profile: bool = False,
+    trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
     **kwargs,
 ):
     model, optim, dataloader = build_components(
@@ -56,33 +58,56 @@ def main(
         print_rank0(f"Saving checkpoint to {checkpoint_dir}...")
         save_model_and_optim_state(checkpoint_dir, model, optim)
 
+    profiler = contextlib.nullcontext()
+    if profile:
+        from torch.profiler import ProfilerActivity, schedule
+
+        def on_trace_ready(p):
+            trace_path = Path(trace_output).expanduser()
+            trace_path.parent.mkdir(exist_ok=True, parents=True)
+            p.export_chrome_trace(str(trace_path))
+            print_rank0(f"Tracing complete, saved to '{trace_path}'")
+
+        profiler = torch.profiler.profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=False,
+            profile_memory=False,
+            with_stack=True,
+            schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
+            on_trace_ready=on_trace_ready,
+        )
+
     print_rank0("Starting training...")
-    for i, batch in enumerate(iter(dataloader)):
-        log.debug("Batch: %s", batch)
-        batch_start = time.monotonic()
+    with profiler as p:
+        for i, batch in enumerate(iter(dataloader)):
+            log.debug("Batch: %s", batch)
+            batch_start = time.monotonic()
 
-        # Zero-gradients.
-        optim.zero_grad()
+            # Zero-gradients.
+            optim.zero_grad()
 
-        # Run forward pass.
-        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
-            loss = compute_loss(model, batch)
+            # Run forward pass.
+            with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
+                loss = compute_loss(model, batch)
 
-        # Trigger backward pass.
-        loss.backward()
+            # Trigger backward pass.
+            loss.backward()
 
-        # Clip gradient norms.
-        model.clip_grad_norm_(1.0)
+            # Clip gradient norms.
+            model.clip_grad_norm_(1.0)
 
-        # Take optimizer step.
-        optim.step()
+            # Take optimizer step.
+            optim.step()
 
-        batch_end = time.monotonic()
-        print_rank0(
-            f"Batch [{i+1}/{num_batches}]:\n"
-            f"  loss={loss.item():.3f}\n"
-            f"  throughput/seconds_per_batch={batch_end-batch_start:.3f}",
-        )
+            batch_end = time.monotonic()
+            print_rank0(
+                f"Batch [{i+1}/{num_batches}]:\n"
+                f"  loss={loss.item():.3f}\n"
+                f"  throughput/seconds_per_batch={batch_end-batch_start:.3f}",
+            )
+
+            if p is not None:
+                p.step()
 
     if save_path is not None:
         checkpoint_dir = Path(save_path) / "final"
@@ -126,6 +151,15 @@ def main(
         "--debug",
         action="store_true",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--trace-output",
+        type=str,
+        default="/tmp/traces/olmo_core.chrome_trace.json.gz",
+    )
     parser.add_argument(
         "--save-path",
         type=str,
@@ -168,7 +202,7 @@ def main(
         raise NotImplementedError(args.model_size)
 
     if args.debug:
-        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
         config.debug = True
 
     dist.init_process_group(backend="nccl")
@@ -185,6 +219,8 @@ def main(
         dry_run=args.dry_run,
         save_path=args.save_path,
         load_path=args.load_path,
+        profile=args.profile,
+        trace_output=args.trace_output,
         mixed_precision=mixed_precision,
         max_prefetch_count=args.max_prefetch_count,
         learning_rate=args.lr,
diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py
index abcbcc45..a9f47f3f 100644
--- a/src/olmo_core/distributed/fsdp/fsdp.py
+++ b/src/olmo_core/distributed/fsdp/fsdp.py
@@ -26,6 +26,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.autograd import Variable
 
 from olmo_core.distributed.tensors import ShardedFlatParameter
 from olmo_core.stream import Stream
@@ -322,7 +323,7 @@ def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tens
         nonsharded_params: Set[nn.Parameter] = set()
         grads: List[torch.Tensor] = []
         for param in self.parameters():
-            if param.grad is None:
+            if param.grad is None or param.grad.numel() == 0:
                 continue
 
             if isinstance(param, ShardedFlatParameter):
@@ -394,7 +395,11 @@ def _lazy_init(self):
             self.state.forward_execution_order.append(self)
             return
 
-        log.debug("Completing lazy initialization from root FSDP for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Completing lazy initialization from root FSDP for %s (%s)...",
+            self.module.__class__.__name__,
+            id(self.module),
+        )
 
         # Initialize streams.
         self.state.compute_stream = Stream.default(self.device)
@@ -494,7 +499,7 @@ def _shard(self):
 
         This should only be called once at initialization.
""" - log.debug("Sharding %s...", self.module.__class__.__name__) + log.debug("Sharding %s (%s)...", self.module.__class__.__name__, id(self.module)) params_with_grads: List[nn.Parameter] = [] params_with_grads_fqns: List[str] = [] @@ -568,7 +573,7 @@ def _unshard( kwargs = dict(cast=cast, set_grads=set_grads, recurse=recurse, rank0_only=rank0_only) - log.debug("Unsharding %s...", self.module.__class__.__name__) + log.debug("Unsharding %s (%s)...", self.module.__class__.__name__, id(self.module)) self.state.params_prefetched = True # NOTE: `unshard_stream` should wait on current stream (usually `compute_stream` / `default_stream`) @@ -600,7 +605,11 @@ def _unshard( def _prefetch(self, prefetch_from: deque[FSDP], **kwargs): for module in self._deque_from(prefetch_from): log.debug( - "Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__ + "Prefetching %s (%s) from %s (%s)...", + module.module.__class__.__name__, + id(module.module), + self.module.__class__.__name__, + id(self.module), ) module._unshard(**kwargs) @@ -611,7 +620,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False): """ kwargs = dict(writeback=writeback, recurse=recurse) - log.debug("Resharding %s...", self.module.__class__.__name__) + log.debug("Resharding %s (%s)...", self.module.__class__.__name__, id(self.module)) self.state.params_prefetched = False for handle in self.state.flat_param_handles: @@ -637,7 +646,7 @@ def _reduce_scatter_grads(self): grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype with self.state.reduce_stream(wait_stream=self.state.current_stream): - log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__) + log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module)) for handle in self.state.flat_param_handles: handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype) @@ -659,13 +668,16 @@ def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None @torch.no_grad() def _pre_backward_hook(self, *unused: Any): del unused - log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__) + log.debug("Running pre-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module)) # Remove all pre backward hooks for this FSDP instance since they all do the same thing. for handle in self.state.pre_backward_hook_handles: handle.remove() self.state.pre_backward_hook_handles.clear() + if self.is_root: + self._register_post_backward_final_hook() + # Unshard parameters in place. self._unshard(set_grads=True) @@ -684,10 +696,12 @@ def _register_pre_backward_hook(self, x: torch.Tensor): self.state.pre_backward_hook_handles.append(handle) def _register_pre_backward_hooks(self, output: Any): - log.debug("Registering pre-backward hooks for %s...", self.module.__class__.__name__) + log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)) # Clear existing hooks if there are any. 
         if self.state.pre_backward_hook_handles:
-            log.debug("Removing old pre-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
             for handle in self.state.pre_backward_hook_handles:
                 handle.remove()
             self.state.pre_backward_hook_handles.clear()
@@ -699,7 +713,6 @@ def _register_pre_backward_hooks(self, output: Any):
     @torch.no_grad()
     def _post_backward_hook(self, param_name: str, *unused: Any):
         del unused
-        log.debug("Running post-backward hook for %s.%s...", self.module.__class__.__name__, param_name)
         self.state.post_backward_hook_handles.pop(param_name).remove()
 
         # If there are still more handles then there are still more post-backward hooks to be ran
@@ -707,21 +720,12 @@ def _post_backward_hook(self, param_name: str, *unused: Any):
         if self.state.post_backward_hook_handles:
             return
 
+        log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
         # NOTE: reshard *before* reducing grads to correctly handle precision settings.
         self._reshard()
         self._reduce_scatter_grads()
 
-        # The root FSDP instance needs to do some final cleanup.
-        if not self.is_root:
-            return
-
-        # Mark backward execution order as finalized.
-        self.state.backward_execution_order_finalized = True
-
-        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
-        # state before grad clipping, optimizer step, or whatever else.
-        self.state.current_stream.wait_stream(self.state.reduce_stream)
-
     def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
         # Force creation of a `grad_fn` in order to register a hook that will run *after* this param's
         # backward pass.
@@ -733,13 +737,42 @@ def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParame
         self.state.post_backward_hook_handles[param_name] = handle
 
     def _register_post_backward_hooks(self):
-        log.debug("Registering post-backward hooks for %s...", self.module.__class__.__name__)
+        log.debug(
+            "Registering post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
         # Clear existing hooks if there are any.
         if self.state.post_backward_hook_handles:
-            log.debug("Removing old post-backward hooks for %s...", self.module.__class__.__name__)
+            log.debug(
+                "Removing old post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
+            )
             for handle in self.state.post_backward_hook_handles.values():
                 handle.remove()
             self.state.post_backward_hook_handles.clear()
         for param_name, param in self._managed_named_parameters():
             if param.requires_grad:
                 self._register_post_backward_hook(param_name, param)
+
+    @torch.no_grad()
+    def _post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug("Running post-backward final hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+
+        # Mark backward execution order as finalized.
+        self.state.backward_execution_order_finalized = True
+        for child in self._fsdp_children(recurse=True):
+            child.state.backward_execution_order_finalized = True
+
+        # Wait for unsharding and reducing streams to complete so the model is not left in a bad
+        # state before grad clipping, optimizer step, or whatever else.
+        self.state.current_stream.wait_stream(self.state.reduce_stream)
+
+    def _register_post_backward_final_hook(self):
+        if not self.is_root:
+            return
+
+        log.debug(
+            "Registering post-backward final hook for %s (%s)...", self.module.__class__.__name__, id(self.module)
+        )
+        Variable._execution_engine.queue_callback(self._post_backward_final_hook)
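
Note on the profiler wiring added to train.py above: it follows the standard torch.profiler schedule / on_trace_ready / step() pattern. Below is a minimal, self-contained sketch of that pattern; the toy matmul workload and the trace path are illustrative assumptions, not part of this diff (the benchmark profiles CPU and CUDA and writes to the path given by --trace-output).

```python
# Minimal sketch (not from this diff): the schedule/on_trace_ready/step() pattern
# that train.py now uses, with a toy CPU-only workload standing in for the real
# training loop and a hypothetical trace path.
import torch
from torch.profiler import ProfilerActivity, profile, schedule


def on_trace_ready(p):
    # Called once the schedule's active window completes; a ".json.gz" suffix
    # makes export_chrome_trace write a gzip-compressed Chrome trace.
    p.export_chrome_trace("/tmp/example.chrome_trace.json.gz")


x = torch.randn(256, 256)
with profile(
    activities=[ProfilerActivity.CPU],  # the benchmark also adds ProfilerActivity.CUDA
    with_stack=True,
    schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
    on_trace_ready=on_trace_ready,
) as p:
    for _ in range(10):  # needs at least wait + warmup + active = 9 steps
        y = x @ x
        p.step()  # advance the profiler's schedule once per "batch"
```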
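
Note on the fsdp.py change above: the root-only cleanup moves out of the last per-parameter post-backward hook and into a final hook that the root instance queues on the autograd engine via Variable._execution_engine.queue_callback, which runs once after the entire backward pass finishes. A minimal sketch of that mechanism follows; the toy tensor, hook names, and print are illustrative, not from the diff.

```python
# Minimal sketch (not from this diff) of Variable._execution_engine.queue_callback:
# a callback queued from inside the backward pass runs once, after backward completes,
# which is how the root FSDP instance now schedules _post_backward_final_hook.
import torch
from torch.autograd import Variable


def final_hook():
    print("entire backward pass finished")


def queue_final_hook(grad):
    # Queued from within backward (here via a tensor grad hook), mirroring how
    # FSDP queues its final hook from the root pre-backward hook.
    Variable._execution_engine.queue_callback(final_hook)
    return grad


x = torch.randn(4, requires_grad=True)
x.register_hook(queue_final_hook)
(x * 2).sum().backward()  # the message prints only after all grads are computed
```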