From 963f7dd00c508ff7a62b32294398cd2bbfcde500 Mon Sep 17 00:00:00 2001
From: Pete
Date: Fri, 12 Apr 2024 15:34:05 -0700
Subject: [PATCH] FSDP fixes (#8)

* check for nans when unsharding
* don't cast root forward inputs
* use all_gather_into_tensor when possible
* fix
* fix
* fix dtype
* clean up
* clean up
* assert dtypes in `summon_full_params` context
* no write back
* more tests
* explicit cast when checking
* updates
* ensure `cast` and `writeback` not both set
* cast in other direction
* revert
* Add mp option to train script
* update
* adjust LR
* update prefetching logic in forward pass
* update backward prefetch logic
* define stream in top-level of package
* debugging
* add to stream test
* more test
* clean up
* Try recording stream
* Add comment
* clean up
* updates
* fix how many mods are prefetched
* don't check for nan loss
* clean up
---
 src/benchmarks/fsdp/common.py                 | 27 +++++-
 src/benchmarks/fsdp/test.py                   | 37 ++++++--
 src/benchmarks/fsdp/train.py                  | 35 ++++++--
 src/olmo_core/distributed/checkpoint.py       | 15 +++-
 .../distributed/fsdp/flat_param_handle.py     | 35 ++++++--
 src/olmo_core/distributed/fsdp/fsdp.py        | 86 +++++++++++--------
 src/olmo_core/distributed/fsdp/state.py       |  2 +-
 .../{distributed/fsdp => }/stream.py          | 10 +++
 src/test/distributed/fsdp/fsdp_test.py        |  2 +-
 src/test/distributed/utils.py                 | 63 ++++++--------
 .../{distributed/fsdp => }/stream_test.py     | 14 ++-
 src/test/utils.py                             | 48 +++++++++++
 12 files changed, 272 insertions(+), 102 deletions(-)
 rename src/olmo_core/{distributed/fsdp => }/stream.py (86%)
 rename src/test/{distributed/fsdp => }/stream_test.py (61%)
 create mode 100644 src/test/utils.py

diff --git a/src/benchmarks/fsdp/common.py b/src/benchmarks/fsdp/common.py
index b7b8d551..e466003d 100644
--- a/src/benchmarks/fsdp/common.py
+++ b/src/benchmarks/fsdp/common.py
@@ -29,6 +29,7 @@ class TransformerConfig:
     mlp_ratio: int = 4
     max_sequence_length: int = 2048
     init_device: torch.device = torch.device("cpu")
+    debug: bool = False

     @classmethod
     def tiniest(cls) -> TransformerConfig:
@@ -50,6 +51,7 @@ def medium(cls) -> TransformerConfig:
 class Transformer(nn.Module):
     def __init__(self, config: TransformerConfig):
         super().__init__()
+        self.config = config
         self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
         self.wpe = nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)
         self.blocks = nn.ModuleList(
@@ -82,11 +84,24 @@ def __init__(self, config: TransformerConfig):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.config.debug:
+            for param in self.parameters(recurse=False):
+                assert not param.isnan().any()
+            assert not x.isnan().any()
         x = self.wte(x)
+        if self.config.debug:
+            assert not x.isnan().any()
         x = x + self.wpe(self.positions)
+        if self.config.debug:
+            assert not x.isnan().any()
         for block in self.blocks:
             x = block(x, src_mask=self.causal_mask, is_causal=True)
-        return self.decoder(x)
+        if self.config.debug:
+            assert not x.isnan().any()
+        x = self.decoder(x)
+        if self.config.debug:
+            assert not x.isnan().any()
+        return x


 class Dataloader:
@@ -135,6 +150,8 @@ def build_components(
     fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
     wrap_blocks: bool = True,
     mixed_precision: bool = True,
+    max_prefetch_count: int = 1,
+    learning_rate: float = 1e-4,
 ) -> Tuple[nn.Module, torch.optim.Optimizer, Dataloader]:
     model = Transformer(config)

@@ -148,6 +165,7 @@ def build_components(
             precision=FSDPPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
             if mixed_precision
             else None,
+            max_prefetch_count=max_prefetch_count,
         )
         model.apply(init_function)

@@ -164,7 +182,10 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
         model = FullyShardedDataParallel(
             model,
             mixed_precision=MixedPrecision(
-                param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32
+                param_dtype=torch.bfloat16,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+                cast_root_forward_inputs=False,
             )
             if mixed_precision
             else None,
@@ -180,7 +201,7 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
     print_rank0(model)

     print_rank0("Initializing optimizer...")
-    optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
+    optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)

     return model, optim, Dataloader(batch_size, config, num_batches=num_batches)

diff --git a/src/benchmarks/fsdp/test.py b/src/benchmarks/fsdp/test.py
index 80a781fb..71d08165 100644
--- a/src/benchmarks/fsdp/test.py
+++ b/src/benchmarks/fsdp/test.py
@@ -61,15 +61,35 @@ def main(
     load_model_and_optim_state(checkpoint_dir, olmo_model, olmo_optim)

     print_rank0("Checking state dict...")
-    with TorchFSDP.summon_full_params(torch_model), olmo_model.summon_full_params():
-        torch_state_dict = {k.replace("_fsdp_wrapped_module.", ""): v for k, v in torch_model.state_dict().items()}
-        olmo_state_dict = olmo_model.state_dict()
-        assert torch_state_dict.keys() == olmo_state_dict.keys()
-        for key in torch_state_dict:
+    with TorchFSDP.summon_full_params(torch_model, writeback=False), olmo_model.summon_full_params(
+        writeback=False
+    ):
+        torch_fp32_state_dict = {
+            k.replace("_fsdp_wrapped_module.", ""): v for k, v in torch_model.state_dict().items()
+        }
+        olmo_fp32_state_dict = olmo_model.state_dict()
+        assert torch_fp32_state_dict.keys() == olmo_fp32_state_dict.keys()
+        for key in torch_fp32_state_dict:
+            assert torch_fp32_state_dict[key].dtype == torch.float32
+            assert olmo_fp32_state_dict[key].dtype == torch.float32
             torch.testing.assert_close(
-                torch_state_dict[key], olmo_state_dict[key], msg=lambda msg: f"Failure for {key}: {msg}"
+                torch_fp32_state_dict[key], olmo_fp32_state_dict[key], msg=lambda msg: f"Failure for {key}: {msg}"
             )

+    if mixed_precision:
+        print_rank0("Checking gathering full params in low precision...")
+        with olmo_model.summon_full_params(cast=True, writeback=False):
+            olmo_bf16_state_dict = olmo_model.state_dict()
+            assert olmo_bf16_state_dict.keys() == olmo_fp32_state_dict.keys()
+            for key in olmo_bf16_state_dict.keys():
+                torch.testing.assert_close(
+                    olmo_bf16_state_dict[key],
+                    olmo_fp32_state_dict[key].to(torch.bfloat16),
+                    msg=lambda msg: f"Failure for {key}: {msg}",
+                    rtol=1.3e-6,
+                    atol=1e-5,
+                )
+
     if dry_run:
         print_rank0("Dry run complete")
         return
@@ -84,6 +104,11 @@ def main(
     olmo_logits = olmo_model(batch1)
     torch_loss = compute_loss(torch_model, batch1, logits=torch_logits)
     olmo_loss = compute_loss(olmo_model, batch1, logits=olmo_logits)
+
+    if mixed_precision:
+        assert torch_logits.dtype == torch.bfloat16
+        assert olmo_logits.dtype == torch.bfloat16
+
     torch.testing.assert_close(olmo_logits, torch_logits)
     torch.testing.assert_close(olmo_loss, torch_loss)

diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py
index ecfc792d..f82c9b04 100644
--- a/src/benchmarks/fsdp/train.py
+++ b/src/benchmarks/fsdp/train.py
@@ -31,9 +31,16 @@ def main(
     dry_run: bool = False,
     save_path: Optional[str] = None,
     load_path: Optional[str] = None,
+    mixed_precision: bool = True,
+    **kwargs,
 ):
     model, optim, dataloader = build_components(
-        config, batch_size, num_batches=num_batches, fsdp_wrapper=fsdp_wrapper
+        config,
+        batch_size,
+        num_batches=num_batches,
+        fsdp_wrapper=fsdp_wrapper,
+        mixed_precision=mixed_precision,
+        **kwargs,
     )

     if load_path is not None:
@@ -58,13 +65,11 @@ def main(
         optim.zero_grad()

         # Run forward pass.
-        with torch.autocast("cuda", dtype=torch.bfloat16):
+        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
             loss = compute_loss(model, batch)

         # Trigger backward pass.
         loss.backward()
-        if not torch.isfinite(loss):
-            raise ValueError("NaN loss encountered.")

         # Clip gradient norms.
         model.clip_grad_norm_(1.0)
@@ -76,7 +81,7 @@ def main(
         print_rank0(
             f"Batch [{i+1}/{num_batches}]:\n"
             f"  loss={loss.item():.3f}\n"
-            f"  throughput/seconds_per_batch={batch_end-batch_start:.1f}",
+            f"  throughput/seconds_per_batch={batch_end-batch_start:.3f}",
         )

     if save_path is not None:
@@ -129,8 +134,24 @@ def main(
         "--load-path",
         type=str,
     )
+    parser.add_argument(
+        "--no-mixed-precision",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--max-prefetch-count",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1e-4,
+    )
     args = parser.parse_args()

+    mixed_precision = not args.no_mixed_precision
+
     config: TransformerConfig
     if args.model_size == "tiny":
         config = TransformerConfig.tiny()
@@ -143,6 +164,7 @@ def main(

     if args.debug:
         os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        config.debug = True

     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(dist.get_rank())
@@ -158,4 +180,7 @@ def main(
         dry_run=args.dry_run,
         save_path=args.save_path,
         load_path=args.load_path,
+        mixed_precision=mixed_precision,
+        max_prefetch_count=args.max_prefetch_count,
+        learning_rate=args.lr,
     )
diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py
index 2c00b363..876d9b8b 100644
--- a/src/olmo_core/distributed/checkpoint.py
+++ b/src/olmo_core/distributed/checkpoint.py
@@ -501,6 +501,11 @@ def unshard(
         # Load the state dict in place.
         self.load(dir, state_dict, metadata=metadata, no_dist=no_dist or rank0_only)

+        # Check for NaNs which would indicate we didn't fill the state dict correctly.
+        for key, tensor in state_dict.items():
+            if tensor.isnan().any().item():
+                raise RuntimeError(f"error loading {key} from checkpoint, NaNs encountered")
+
         return state_dict

     def get_metadata(self, dir: str, no_dist: bool = False) -> StorageMetadata:
@@ -665,7 +670,10 @@ def torch_dtype(self) -> torch.dtype:
     def materialize_empty(
         self, *, device: Optional[torch.device] = None, shape: Optional[Tuple[int, ...]] = None
     ) -> torch.Tensor:
-        return torch.empty(shape if shape is not None else self.shape, dtype=self.torch_dtype, device=device)
+        tensor = torch.empty(shape if shape is not None else self.shape, dtype=self.torch_dtype, device=device)
+        if tensor.dtype.is_floating_point:
+            tensor.fill_(torch.nan)
+        return tensor

     def materialize_from_sharded(
         self, tensor: torch.Tensor, device: Optional[torch.device] = None
@@ -675,7 +683,10 @@ def materialize_from_sharded(
                 raise ValueError(
                     f"unexpected shape for sharded tensor, expected {self.shape}, got {tensor.unsharded_shape}"
                 )
-            return torch.empty(tensor.shape, device=device, dtype=self.torch_dtype)
+            tensor = torch.empty(tensor.shape, device=device, dtype=self.torch_dtype)
+            if tensor.dtype.is_floating_point:
+                tensor.fill_(torch.nan)
+            return tensor
         else:
             raise NotImplementedError(f"`materialize_from_sharded()` not implemented for {tensor}")

diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py
index d1324f37..08254d45 100644
--- a/src/olmo_core/distributed/fsdp/flat_param_handle.py
+++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -32,7 +32,7 @@ class FlatParamHandle:
     The FQNs of the managed params.
     """

-    grads: List[Optional[torch.Tensor]] = field(default_factory=list)
+    grads_cache: List[Optional[torch.Tensor]] = field(default_factory=list)
     """
     Used for caching gradients during gradient accumulation.
     """
@@ -121,7 +121,7 @@ def collate_flat_params(
         return cls(
             params=params,
             param_fqns=list(param_fqns),
-            grads=[None] * len(params),
+            grads_cache=[None] * len(params),
             params_data=params_data,
             params_offsets_per_rank=params_offsets_per_rank,
             process_group=process_group,
@@ -134,9 +134,27 @@ def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False
         """
         if not self.params:
             return
+
         local_rank = get_rank(self.process_group)
         world_size = get_world_size(self.process_group)
-        all_params_unsharded_data = self.params_data.gather(dtype=dtype, rank0_only=rank0_only)
+
+        # Gather full, padded, unsharded data for all params.
+        all_params_unsharded_data: torch.Tensor
+        if rank0_only or dist.get_backend() == dist.Backend.GLOO:
+            all_params_unsharded_data = self.params_data.gather(dtype=dtype, rank0_only=rank0_only)
+        else:
+            # We prefer to use `all_gather_into_tensor()` directly when possible as it involves
+            # fewer allocations.
+            all_params_unsharded_data = torch.empty(
+                self.params_data.unsharded_shape, dtype=dtype or self.params_data.dtype, device=self.device
+            )
+            dist.all_gather_into_tensor(
+                all_params_unsharded_data,
+                self.params_data.data.to(dtype or self.params_data.dtype),
+                group=self.process_group,
+            )
+
+        # Set the data for each param as a view into `all_params_unsharded_data`.
         for i, (param, param_offsets) in enumerate(zip(self.params, self.params_offsets_per_rank)):
             if rank0_only and local_rank != 0:
                 param.unshard_(
@@ -163,8 +181,8 @@ def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False
                     # We should only be caching these between the pre-backward and post-backward
                    # hooks. The post-backward hook will remove the cached grad as it accumulates
                    # it into the persistent sharded grad.
-                    assert self.grads[i] is None
-                    self.grads[i] = param.grad.data
+                    assert self.grads_cache[i] is None
+                    self.grads_cache[i] = param.grad.data
                     param.grad = None

         del all_params_unsharded_data
@@ -194,7 +212,8 @@ def reduce_scatter_grads(
             if grad_dtype is None:
                 grad_dtype = param.dtype

-            # TODO: batch reductions together
+            # TODO: batch reductions together? This is complicated, especially if we want to allow
+            # a mixture of trainable and frozen params.

             # Only NCCL supports 'reduce_scatter'. So with other backends we use 'all_reduce'.
             if dist.get_backend() == dist.Backend.NCCL:
@@ -209,7 +228,7 @@ def reduce_scatter_grads(

             del unsharded_grad

-            if (cached_grad := self.grads[i]) is not None:
+            if (cached_grad := self.grads_cache[i]) is not None:
                 param.grad.add_(cached_grad)
-                self.grads[i] = None
+                self.grads_cache[i] = None
                 del cached_grad
diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py
index e8ba15bd..80b85ce1 100644
--- a/src/olmo_core/distributed/fsdp/fsdp.py
+++ b/src/olmo_core/distributed/fsdp/fsdp.py
@@ -28,11 +28,11 @@
 import torch.nn as nn

 from olmo_core.distributed.tensors import ShardedFlatParameter
+from olmo_core.stream import Stream
 from olmo_core.utils import apply_to_tensors, gc_cuda, get_default_device, get_grad_norm

 from .flat_param_handle import FlatParamHandle
 from .state import FSDPState
-from .stream import Stream

 log = logging.getLogger(__name__)

@@ -158,23 +158,25 @@ def forward(self, *args, **kwargs):
         """
         self._lazy_init()

-        log.debug("Running forward pass for %s...", self.module.__class__.__name__)
-
         if self.is_root and self.state.forward_execution_order_finalized:
             # Fill forward-pass prefetch queue for unsharding.
             for module in self.state.forward_execution_order:
                 self.state.forward_prefetch_queue.append(module)

         # Unshard parameters in-place.
-        self._unshard(
-            prefetch_from=self.state.forward_prefetch_queue
-            if self.state.forward_execution_order_finalized
-            else None
-        )
+        self._unshard()

         try:
-            # Run forward pass on the original model.
-            with self.state.compute_stream(wait_stream=self.state.unshard_stream):
+            # Wait for unsharding stream before running the wrapped module's forward pass.
+            self.state.compute_stream.wait_stream(self.state.unshard_stream)
+
+            # Then we can prefetch the next FSDP module(s) asynchronously.
+            if self.state.forward_execution_order_finalized:
+                self._prefetch(self.state.forward_prefetch_queue)
+
+            # Run forward pass on the wrapped module.
+            with self.state.compute_stream:
+                log.debug("Running forward pass for %s...", self.module.__class__.__name__)
                 output = self.module(*args, **kwargs)

             if torch.is_grad_enabled():
@@ -267,7 +269,9 @@ def named_parameters(self, *args, **kwargs):
             yield key_mapping.get(name, name), param

     @contextmanager
-    def summon_full_params(self, recurse: bool = True, writeback: bool = True, rank0_only: bool = False):
+    def summon_full_params(
+        self, recurse: bool = True, writeback: bool = True, rank0_only: bool = False, cast: bool = False
+    ):
         """
         Gather full unsharded params in-place with this context manager.

         :param recurse: Gather unsharded params for all child FSDP instances as well.
         :param writeback: Write the unsharded data back from rank 0 to all other ranks while exiting
             the context manager.
         :param rank0_only: Only summon full params on rank 0.
+        :param cast: If using a mixed-precision strategy, params are cast to the dtype they would
+            have during the forward and backward passes. If this is ``True``, ``writeback`` must be
+            ``False``.
         """
-        self._unshard(cast=False, recurse=recurse, rank0_only=rank0_only)
+        if cast and writeback:
+            raise ValueError("`summon_full_params` with `cast=True` and `writeback=True` is not supported")
+        self._unshard(cast=cast, recurse=recurse, rank0_only=rank0_only)
         self.state.current_stream.wait_stream(self.state.unshard_stream)
         try:
             yield self
         finally:
             self._reshard(writeback=writeback, recurse=recurse)
-            self.state.current_stream.wait_stream(self.state.unshard_stream)

     def apply(self, fn):
         """
@@ -514,7 +522,6 @@ def _unshard(
         cache_grads: bool = False,
         recurse: bool = False,
         rank0_only: bool = False,
-        prefetch_from: Optional[deque[FSDP]] = None,
     ):
         """
         Unshard the wrapped module in place.
@@ -535,17 +542,17 @@ def _unshard(
             dtype=self.precision.param_dtype if cast else None, rank0_only=rank0_only, cache_grads=cache_grads
         )

-        if prefetch_from is not None:
-            for module in self._deque_from(prefetch_from):
-                log.debug(
-                    "Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
-                )
-                module._unshard(**kwargs)
-
         if recurse:
             for module in self._fsdp_children():
                 module._unshard(**kwargs)

+    def _prefetch(self, prefetch_from: deque[FSDP], **kwargs):
+        for module in self._deque_from(prefetch_from):
+            log.debug(
+                "Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
+            )
+            module._unshard(**kwargs)
+
     @torch.no_grad()
     def _reshard(self, writeback: bool = False, recurse: bool = False):
         """
@@ -556,8 +563,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False):
         """
         log.debug("Resharding %s...", self.module.__class__.__name__)
         self.state.params_prefetched = False
-        with self.state.unshard_stream(wait_stream=self.state.compute_stream):
-            self.state.flat_param_handle.reshard_(writeback=writeback)
+        self.state.flat_param_handle.reshard_(writeback=writeback)

         if recurse:
             for module in self._fsdp_children():
@@ -580,13 +586,22 @@ def _reduce_scatter_grads(self):
         # dtype just for reducing gradients.
         grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype

+        og_grads = [param.grad for param in self.state.flat_param_handle.params if param.grad is not None]
+
         with self.state.reduce_stream(wait_stream=self.state.current_stream):
             log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__)
             self.state.flat_param_handle.reduce_scatter_grads(grad_reduce_dtype=grad_reduce_dtype)
+            # Reduce-scattering the grads relies on the original (local) grads,
+            # which are produced in the current stream being used for the backward pass.
+            # Since we're using a separate stream for the reduce-scatter, we need to make sure those
+            # grads are not deallocated before the reduce-scatter finishes.
+            for og_grad in og_grads:
+                self.state.reduce_stream.record_for(og_grad)
+
     def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
         count = 0
-        while prefetch_queue and count <= self.max_prefetch_count:
+        while prefetch_queue and count < self.max_prefetch_count:
             module = prefetch_queue.popleft()
             if module is not self:
                 count += 1
@@ -604,26 +619,24 @@ def _pre_backward_hook(self, *unused: Any):
         del unused
         log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__)

-        if not self.state.backward_execution_order_finalized:
-            # Add self to backward execution order.
-            self.state.backward_execution_order.append(self)
-
-        # Unshard parameters in place.
-        self._unshard(
-            cache_grads=True,
-            prefetch_from=self.state.backward_prefetch_queue
-            if self.state.backward_execution_order_finalized
-            else None,
-        )
-
         # Remove all pre backward hooks for this FSDP instance since they all do the same thing.
         for handle in self.state.pre_backward_hook_handles:
             handle.remove()
         self.state.pre_backward_hook_handles.clear()

+        # Unshard parameters in place.
+        self._unshard(cache_grads=True)
+
         # Wait for unshard stream so gradient computation can proceed.
         self.state.current_stream.wait_stream(self.state.unshard_stream)

+        if self.state.backward_execution_order_finalized:
+            # Prefetch next FSDP module(s) asynchronously.
+            self._prefetch(self.state.backward_prefetch_queue, cache_grads=True)
+        else:
+            # Add self to backward execution order.
+            self.state.backward_execution_order.append(self)
+
     def _register_pre_backward_hook(self, x: torch.Tensor):
         handle = x.register_hook(self._pre_backward_hook)
         self.state.pre_backward_hook_handles.append(handle)
@@ -667,7 +680,6 @@ def _post_backward_hook(self, param_name: str, *unused: Any):

         # Wait for unsharding and reducing streams to complete so the model is not left in a bad
         # state before grad clipping, optimizer step, or whatever else.
-        self.state.current_stream.wait_stream(self.state.unshard_stream)
         self.state.current_stream.wait_stream(self.state.reduce_stream)

     def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
diff --git a/src/olmo_core/distributed/fsdp/state.py b/src/olmo_core/distributed/fsdp/state.py
index 819339b0..51c0517c 100644
--- a/src/olmo_core/distributed/fsdp/state.py
+++ b/src/olmo_core/distributed/fsdp/state.py
@@ -7,10 +7,10 @@
 import torch
 from torch.utils.hooks import RemovableHandle

+from olmo_core.stream import Stream
 from olmo_core.utils import get_default_device

 from .flat_param_handle import FlatParamHandle
-from .stream import Stream

 if TYPE_CHECKING:
     from .fsdp import FSDP
diff --git a/src/olmo_core/distributed/fsdp/stream.py b/src/olmo_core/stream.py
similarity index 86%
rename from src/olmo_core/distributed/fsdp/stream.py
rename to src/olmo_core/stream.py
index 78ceaab7..6f6c2a28 100644
--- a/src/olmo_core/distributed/fsdp/stream.py
+++ b/src/olmo_core/stream.py
@@ -51,6 +51,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     def wait_stream(self, other: Stream):
         del other

+    def record_for(self, tensor: torch.Tensor):
+        del tensor
+

 class CudaStream(Stream):
     def __init__(self, base_stream: torch.cuda.Stream):
@@ -69,6 +72,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
     def wait_stream(self, other: Stream):
         if isinstance(other, CudaStream):
             self.base_stream.wait_stream(other.base_stream)
+        elif isinstance(other, torch.cuda.Stream):
+            self.base_stream.wait_stream(other)
+        elif not isinstance(other, Stream):
+            raise ValueError(f"expected a Stream, got {type(other)}")
+
+    def record_for(self, tensor: torch.Tensor):
+        tensor.record_stream(self.base_stream)


 class CpuStream(Stream):
diff --git a/src/test/distributed/fsdp/fsdp_test.py b/src/test/distributed/fsdp/fsdp_test.py
index b9e4d646..75e1d2cc 100644
--- a/src/test/distributed/fsdp/fsdp_test.py
+++ b/src/test/distributed/fsdp/fsdp_test.py
@@ -159,7 +159,7 @@ def run_fsdp_against_ddp(model_factory, model_data_factory):

     # Since we've only done a single backwards pass (no grad accumulation), there shouldn't
     # be any cached gradients.
-    for cached_grad in fsdp_model.state.flat_param_handle.grads:
+    for cached_grad in fsdp_model.state.flat_param_handle.grads_cache:
         assert cached_grad is None

     # Run optimizer step.
diff --git a/src/test/distributed/utils.py b/src/test/distributed/utils.py
index 2ad957bf..13a7c98d 100644
--- a/src/test/distributed/utils.py
+++ b/src/test/distributed/utils.py
@@ -11,20 +11,37 @@
 from olmo_core.distributed.fsdp import FSDPPrecision
 from olmo_core.distributed.utils import is_distributed

-has_cuda = torch.cuda.is_available()
-has_multiple_gpus = has_cuda and torch.cuda.device_count() > 1
+from ..utils import (
+    DEVICES,
+    GPU_MARKS,
+    INIT_DEVICES,
+    LOW_PRECISION_DTYPES,
+    has_cuda,
+    requires_gpu,
+)
+
+__all__ = [
+    "has_cuda",
+    "has_multiple_gpus",
+    "requires_gpu",
+    "requires_multi_gpu",
+    "get_default_device",
+    "init_process",
+    "run_distributed_test",
+    "DEVICES",
+    "INIT_DEVICES",
+    "BACKENDS",
+    "LOW_PRECISION_DTYPES",
+    "FSDP_MIXED_PRECISION",
+    "GPU_MARKS",
+    "MULTI_GPU_MARKS",
+]

-GPU_MARKS = (pytest.mark.gpu, pytest.mark.skipif(not has_cuda, reason="Requires a GPU"))
+has_multiple_gpus = has_cuda and torch.cuda.device_count() > 1
 MULTI_GPU_MARKS = (pytest.mark.gpu, pytest.mark.skipif(not has_multiple_gpus, reason="Requires multiple GPUs"))


-def requires_gpu(func):
-    for mark in GPU_MARKS:
-        func = mark(func)
-    return func
-
-
 def requires_multi_gpu(func):
     for mark in MULTI_GPU_MARKS:
         func = mark(func)
     return func
@@ -40,34 +57,6 @@ def requires_multi_gpu(func):
     ),
 ]

-DEVICES = [
-    pytest.param(torch.device("cpu"), id="device=CPU"),
-    pytest.param(
-        torch.device("cuda"),
-        id="device=CUDA",
-        marks=GPU_MARKS,
-    ),
-]
-
-INIT_DEVICES = [
-    pytest.param(torch.device("meta"), id="device=meta"),
-    pytest.param(torch.device("cpu"), id="device=CPU"),
-    pytest.param(
-        torch.device("cuda"),
-        id="device=CUDA",
-        marks=GPU_MARKS,
-    ),
-]
-
-LOW_PRECISION_DTYPES = [
-    pytest.param(torch.float16, id="dtype=float16"),
-    pytest.param(
-        torch.bfloat16,
-        id="dtype=bfloat16",
-        marks=GPU_MARKS,
-    ),
-]
-
 FSDP_MIXED_PRECISION = [
     pytest.param(FSDPPrecision(param_dtype=torch.float16, reduce_dtype=None), id="param_dtype=FP16"),
     pytest.param(
diff --git a/src/test/distributed/fsdp/stream_test.py b/src/test/stream_test.py
similarity index 61%
rename from src/test/distributed/fsdp/stream_test.py
rename to src/test/stream_test.py
index c39ece15..dd8aa1df 100644
--- a/src/test/distributed/fsdp/stream_test.py
+++ b/src/test/stream_test.py
@@ -1,8 +1,8 @@
 import torch

-from olmo_core.distributed.fsdp.stream import CudaStream, Stream
+from olmo_core.stream import CudaStream, Stream

-from ..utils import requires_gpu
+from .utils import requires_gpu


 @requires_gpu
@@ -25,6 +25,16 @@ def test_cuda_stream():
     with other_stream:
         assert torch.cuda.current_stream(device) == other_stream.base_stream
         y = torch.sum(x)
+    assert torch.cuda.current_stream(device) == default_stream.base_stream
+
+    default_stream.wait_stream(other_stream)
+    del x, y
+
+    x = torch.empty((100, 100), device=device).normal_(0.0, 1.0)
+    with other_stream(wait_stream=default_stream):
+        assert torch.cuda.current_stream(device) == other_stream.base_stream
+        y = torch.sum(x)
+    assert torch.cuda.current_stream(device) == default_stream.base_stream

     default_stream.wait_stream(other_stream)
     del x, y
diff --git a/src/test/utils.py b/src/test/utils.py
new file mode 100644
index 00000000..fdab2de9
--- /dev/null
+++ b/src/test/utils.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+
+has_cuda = torch.cuda.is_available()
+
+GPU_MARKS = (pytest.mark.gpu, pytest.mark.skipif(not has_cuda, reason="Requires a GPU"))
+
+
+def requires_gpu(func):
+    for mark in GPU_MARKS:
+        func = mark(func)
+    return func
+
+
+INIT_DEVICES = [
+    pytest.param(torch.device("meta"), id="device=meta"),
+    pytest.param(torch.device("cpu"), id="device=CPU"),
+    pytest.param(
+        torch.device("cuda"),
+        id="device=CUDA",
+        marks=GPU_MARKS,
+    ),
+]
+
+DEVICES = [
+    pytest.param(torch.device("cpu"), id="device=CPU"),
+    pytest.param(
+        torch.device("cuda"),
+        id="device=CUDA",
+        marks=GPU_MARKS,
+    ),
+]
+
+LOW_PRECISION_DTYPES = [
+    pytest.param(torch.float16, id="dtype=float16"),
+    pytest.param(
+        torch.bfloat16,
+        id="dtype=bfloat16",
+        marks=GPU_MARKS,
+    ),
+]
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
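
Usage note (editor's addition, not part of the patch): a minimal sketch of how the options introduced above might be exercised. It assumes the `FSDP` wrapper accepts the `precision` and `max_prefetch_count` keywords the way `build_components()` in src/benchmarks/fsdp/common.py passes them, and that `FSDP` is importable from `olmo_core.distributed.fsdp`; `my_module` is a placeholder.

    import torch
    import torch.nn as nn
    from olmo_core.distributed.fsdp import FSDP, FSDPPrecision

    # Assumes torch.distributed is already initialized (e.g. via dist.init_process_group).
    my_module = nn.Linear(16, 16)  # any nn.Module; placeholder for illustration
    fsdp_model = FSDP(
        my_module,
        precision=FSDPPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        max_prefetch_count=1,  # same knob exposed as --max-prefetch-count in train.py
    )

    # Gather full params cast to the forward/backward dtype.
    # Per the new check in summon_full_params, writeback must stay False when cast=True.
    with fsdp_model.summon_full_params(cast=True, writeback=False):
        for name, param in fsdp_model.named_parameters():
            assert param.dtype == torch.bfloat16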