diff --git a/Makefile b/Makefile
index 9a8b272f..0c7195a7 100644
--- a/Makefile
+++ b/Makefile
@@ -7,7 +7,7 @@ style-check :
 	black --check .
 
 .PHONY : lint-check
-format-check :
+lint-check :
 	ruff check .
 
 .PHONY : type-check
diff --git a/docs/source/distributed/tensors.rst b/docs/source/distributed/tensors.rst
new file mode 100644
index 00000000..ea2af14b
--- /dev/null
+++ b/docs/source/distributed/tensors.rst
@@ -0,0 +1,6 @@
+``distributed.tensors``
+=======================
+
+.. automodule:: olmo_core.distributed.tensors
+   :members:
+   :member-order: bysource
diff --git a/docs/source/index.rst b/docs/source/index.rst
index fecda678..22b7eadc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -14,10 +14,11 @@
    :caption: API Reference
 
    exceptions.rst
-   utils.rst
    io.rst
+   utils.rst
    distributed/checkpoint.rst
    distributed/fsdp.rst
+   distributed/tensors.rst
 
 .. toctree::
    :hidden:
diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py
index 7d5380d4..bc05ef6d 100644
--- a/src/olmo_core/distributed/checkpoint.py
+++ b/src/olmo_core/distributed/checkpoint.py
@@ -63,7 +63,7 @@
 )
 from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES
 
-from .sharded_flat_tensor import ShardedFlatTensor, ShardingSpec
+from .tensors import ShardedFlatTensor, ShardingSpec
 from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object
 
 log = logging.getLogger(__name__)
diff --git a/src/olmo_core/distributed/fsdp/__init__.py b/src/olmo_core/distributed/fsdp/__init__.py
index 0bfb0720..e94deda3 100644
--- a/src/olmo_core/distributed/fsdp/__init__.py
+++ b/src/olmo_core/distributed/fsdp/__init__.py
@@ -1,10 +1,77 @@
 """
-This is a light-weight rewrite of PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
+This is a light-weight, experimental rewrite of PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
 with a few improvements, including:
 
 - Well-defined "hands off" handling of buffers. FSDP never shards buffers, they are left as-is.
 - Well-defined handling of frozen params. You can mix and match within an FSDP instance as long
-  as the frozen/non-frozen params are consistent across the process group.
+  as you're consistent across the process group about which parameters are frozen.
+- Full support for CPU-only training and inference via the GLOO backend.
+- Low-overhead checkpointing with :mod:`olmo_core.distributed.checkpoint`.
+
+Usage Tips
+----------
+
+- Always initialize your optimizer *after* wrapping your model with FSDP.
+- When you initialize your model (prior to wrapping with FSDP), use ``device=torch.device("meta")``
+  when initializing *parameters* to save memory. :class:`FSDP` will automatically materialize and
+  move parameters to the right device when wrapping.
+  Then you can use :meth:`FSDP.apply()` to initialize parameters how you want.
+- As with PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`, you should
+  use :func:`FSDP.clip_grad_norm_()` for clipping gradient norms instead of :func:`torch.nn.utils.clip_grad_norm_()`.
+- Use activation checkpointing via :func:`torch.utils.checkpoint.checkpoint()` to save more memory
+  during the forward and backward pass at the expense of more computation.
+- To save and load checkpoints for your FSDP model and its optimizer, use
+  :func:`~olmo_core.distributed.checkpoint.save_model_and_optim_state()` and
+  :func:`~olmo_core.distributed.checkpoint.load_model_and_optim_state()`, respectively.
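+
+Putting these tips together, a rough end-to-end sketch might look like the following. This is
+illustrative only: ``MyModel``, ``init_weights``, and ``data_loader`` are placeholders, the process
+group is assumed to be initialized already, and exact argument names/orders should be checked
+against the API reference below.
+
+.. code-block:: python
+
+    import torch
+
+    from olmo_core.distributed.checkpoint import save_model_and_optim_state
+    from olmo_core.distributed.fsdp import FSDP
+
+    # Initialize the model's parameters on the "meta" device to save memory, then wrap it
+    # with FSDP, which materializes and shards the parameters.
+    model = MyModel(device=torch.device("meta"))  # `MyModel` is a placeholder nn.Module
+    fsdp_model = FSDP(model)
+
+    # Initialize the (now materialized, sharded) parameters in place.
+    fsdp_model.apply(init_weights)  # `init_weights` is your own initialization function
+
+    # Create the optimizer *after* wrapping with FSDP.
+    optim = torch.optim.AdamW(fsdp_model.parameters())
+
+    for batch in data_loader:
+        optim.zero_grad(set_to_none=True)
+        loss = fsdp_model(batch).sum()
+        loss.backward()
+        fsdp_model.clip_grad_norm_(1.0)  # use FSDP's own grad clipping
+        optim.step()
+
+    # Save a low-overhead distributed checkpoint of the model and optimizer state.
+    save_model_and_optim_state("/path/to/checkpoint", fsdp_model, optim)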
+
+Implementation Details
+----------------------
+
+When you wrap a :class:`~torch.nn.Module` with :class:`FSDP`, the wrapping FSDP instance will replace
+each original parameter in the module with a :class:`~olmo_core.distributed.tensors.ShardedFlatParameter` instance,
+and each rank will only keep a shard of the original data. Buffers are left as-is.
+
+.. note::
+    Further, the sharded data for all of the :class:`~olmo_core.distributed.tensors.ShardedFlatParameter`
+    instances will be collected into a single :class:`FlatParamHandle`, and each flat parameter will
+    just hold a view into a slice of the data managed by the handle. This makes gathering the full
+    params more efficient as it only requires a single all-gather per FSDP node.
+
+Forward Pass
+~~~~~~~~~~~~
+
+When the :meth:`~torch.nn.Module.forward()` method is called on the wrapping FSDP instance, it will gather
+the full unsharded data for each parameter in the desired :class:`~torch.dtype`
+(as defined by the :class:`FSDPPrecision` settings) while caching the sharded data behind the scenes.
+Then it runs the forward method of the wrapped module, which is completely unsharded at that point.
+
+After the forward method of the wrapped module returns, the wrapping FSDP instance will reshard
+the parameters and, if gradients are enabled, register backward hooks to manage the state of parameters
+and gradients during the backward pass.
+
+During the first forward pass the root FSDP instance will also record the order of execution of all
+FSDP children, and use that order to prefetch the full parameters for its FSDP children during
+subsequent forward passes. The number of children that are prefetched at once is controlled by the
+``max_prefetch_count`` setting.
+
+.. note::
+    When CUDA is available, :class:`FSDP` instances utilize multiple CUDA streams in order to overlap
+    communication (e.g. unsharding params or reducing gradients) with computation
+    (e.g. the forward pass or computing gradients during the backward pass).
+
+Backward Pass
+~~~~~~~~~~~~~
+
+At the end of the forward method, the wrapping FSDP instance registers ephemeral "pre-backward" and "post-backward" hooks
+to unshard the parameters and reduce-scatter the gradients, respectively, during the backward pass.
+
+At the end of the backward pass the :attr:`~torch.Tensor.grad` attribute of each (non-frozen) parameter will
+be the shard of the full gradient corresponding to the shard of the full parameter, i.e. it will
+have the same shape/size as the sharded parameter.
+
+Just as the root FSDP instance records the execution order of its FSDP children during the first
+forward pass, it also records the order during the first backward pass and uses that
+to prefetch the full parameters of its children during subsequent backward passes.
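+
+To make the gradient layout concrete, the following sketch shows what you could expect to hold
+after a backward pass (``fsdp_model`` and ``batch`` are the illustrative names from the usage
+sketch above):
+
+.. code-block:: python
+
+    loss = fsdp_model(batch).sum()
+    loss.backward()
+
+    # Each managed (non-frozen) parameter now holds only its local shard of the full
+    # gradient, so its grad has the same shape as the sharded flat parameter.
+    for name, param in fsdp_model.module.named_parameters():
+        if param.grad is not None:
+            assert param.grad.shape == param.shape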
 
 API Reference
 -------------
diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py
new file mode 100644
index 00000000..2168523a
--- /dev/null
+++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+from olmo_core.distributed.tensors import (
+    ShardedFlatParameter,
+    ShardedFlatTensor,
+    ShardingSpec,
+)
+from olmo_core.distributed.utils import get_rank, get_world_size
+from olmo_core.utils import get_default_device
+
+
+@dataclass
+class FlatParamHandle:
+    """
+    Manages the data for a group of sharded flat parameters in order to use a single all-gather
+    to unshard all of the parameters at once.
+    """
+
+    params: List[ShardedFlatParameter] = field(default_factory=list)
+    """
+    The params managed by this handle.
+    """
+
+    param_fqns: List[str] = field(default_factory=list)
+    """
+    The FQNs of the managed params.
+    """
+
+    grads: List[Optional[torch.Tensor]] = field(default_factory=list)
+    """
+    Used for caching gradients during gradient accumulation.
+    """
+
+    params_data: ShardedFlatTensor = field(default_factory=lambda: ShardedFlatTensor(torch.empty(0)))
+    """
+    Consolidated data for all of the local sharded data of the parameters.
+    """
+
+    params_offsets_per_rank: List[Dict[int, Tuple[int, int]]] = field(default_factory=list)
+    """
+    For each parameter, provides a mapping of rank to the offsets into the rank's local `params_data`
+    for that parameter.
+    """
+
+    process_group: Optional[dist.ProcessGroup] = None
+
+    device: Optional[torch.device] = None
+
+    @classmethod
+    def collate_flat_params(
+        cls,
+        params: Iterable[ShardedFlatParameter],
+        param_fqns: Iterable[str],
+        process_group: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+    ) -> FlatParamHandle:
+        """
+        Collate the data from a group of sharded flat parameters into a single flat param handle.
+        """
+        device = device or get_default_device()
+        params = list(params)
+        world_size = get_world_size(process_group)
+        local_rank = get_rank(process_group)
+
+        if not params:
+            return cls(process_group=process_group)
+
+        # Find max numel of all sharded flat params across the process group to determine padding.
+        # All ranks will have the same sized `params_data` at the end to avoid needing to pad at runtime.
+        numel_total_per_rank: List[int] = [0] * world_size
+        for param in params:
+            if not param.is_sharded:
+                raise ValueError("All sharded flat params should be sharded at this point!")
+            if param.dtype != torch.float32:
+                raise NotImplementedError("Only float32 params are supported at this time")
+            for rank, n in enumerate(param.sharding_spec.sharded_numels):
+                numel_total_per_rank[rank] += n
+        max_numel = max(numel_total_per_rank)
+
+        # Initialize local data for all params.
+        params_data = ShardedFlatTensor(torch.empty(max_numel, device=device))
+        params_data.mark_as_sharded(
+            ShardingSpec(
+                unsharded_shape=(world_size, max_numel),
+                unsharded_flattened_offsets=tuple(
+                    [
+                        (start_idx, end_idx)
+                        for start_idx, end_idx in zip(
+                            range(0, max_numel * world_size, max_numel),
+                            range(max_numel, max_numel * world_size + 1, max_numel),
+                        )
+                    ]
+                ),
+            ),
+            process_group=process_group,
+        )
+
+        # Consolidate the sharded data from each param into `params_data` and collect offsets.
+ params_offsets_per_rank: List[Dict[int, Tuple[int, int]]] = [] + offset_start_per_rank = {rank: 0 for rank in range(world_size)} + for param in params: + param_offsets: Dict[int, Tuple[int, int]] = {} + for rank in range(world_size): + offset_start = offset_start_per_rank[rank] + offset_end = offset_start + param.sharding_spec.sharded_numels[rank] + param_offsets[rank] = (offset_start, offset_end) + offset_start_per_rank[rank] = offset_end + params_offsets_per_rank.append(param_offsets) + + # Set data for param as a view into `params_data`. + offset_start, offset_end = param_offsets[local_rank] + params_data.data[offset_start:offset_end].copy_(param.data) + param.data = params_data.data[offset_start:offset_end] + + return cls( + params=params, + param_fqns=list(param_fqns), + grads=[None] * len(params), + params_data=params_data, + params_offsets_per_rank=params_offsets_per_rank, + process_group=process_group, + device=device, + ) + + def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False, cache_grads: bool = False): + """ + Unshard the handle's managed flat parameters in-place. + """ + if not self.params: + return + local_rank = get_rank(self.process_group) + world_size = get_world_size(self.process_group) + all_params_unsharded_data = self.params_data.gather(dtype=dtype, rank0_only=rank0_only) + for i, (param, param_offsets) in enumerate(zip(self.params, self.params_offsets_per_rank)): + if rank0_only and local_rank != 0: + param.unshard_( + unsharded_data=torch.empty_like(all_params_unsharded_data), dtype=dtype, rank0_only=rank0_only + ) + else: + unsharded_data = torch.empty( + param.sharding_spec.unsharded_flattened_shape, + dtype=all_params_unsharded_data.dtype, + device=self.device, + ) + for rank in range(world_size): + rank_local_data = all_params_unsharded_data[rank][ + param_offsets[rank][0] : param_offsets[rank][1] + ] + unsharded_data[ + param.sharding_spec.unsharded_flattened_offsets[rank][ + 0 + ] : param.sharding_spec.unsharded_flattened_offsets[rank][1] + ] = rank_local_data + param.unshard_(unsharded_data=unsharded_data, dtype=dtype) + + if cache_grads and param.grad is not None: + # We should only be caching these between the pre-backward and post-backward + # hooks. The post-backward hook will remove the cached grad as it accumulates + # it into the persistent sharded grad. + assert self.grads[i] is None + self.grads[i] = param.grad.detach() + param.grad = None + + del all_params_unsharded_data + + def reshard_(self, writeback: bool = False): + """ + Reshard the handle's managed flat parameters in-place. + """ + local_rank = get_rank(self.process_group) + for param, param_offsets in zip(self.params, self.params_offsets_per_rank): + param.reshard_(writeback=writeback) + if writeback: + offset_start, offset_end = param_offsets[local_rank] + self.params_data.data[offset_start:offset_end].copy_(param.data) + param.data = self.params_data.data[offset_start:offset_end] + + def reduce_scatter_grads( + self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None + ): + for i, param in enumerate(self.params): + if (unsharded_grad := param.grad) is None: + continue + + if grad_reduce_dtype is not None: + unsharded_grad = unsharded_grad.to(dtype=grad_reduce_dtype) + + if grad_dtype is None: + grad_dtype = param.dtype + + # TODO: batch reductions together + + # Only NCCL supports 'reduce_scatter'. So with other backends we use 'all_reduce'. 
+ if dist.get_backend() == dist.Backend.NCCL: + # Get chunks corresponding to each rank. + grad_chunks = param.chunk_unsharded(unsharded_grad, pad=True) + new_sharded_grad = torch.empty_like(grad_chunks[0]) + dist.reduce_scatter(new_sharded_grad, grad_chunks, group=self.process_group) + param.grad = new_sharded_grad[: param.unsharded_flattened_offsets[1]].to(dtype=grad_dtype) + else: + dist.all_reduce(unsharded_grad, group=self.process_group) + param.grad = param.sharded_chunk(unsharded_grad).detach().to(dtype=grad_dtype) + + del unsharded_grad + + if (cached_grad := self.grads[i]) is not None: + param.grad.add_(cached_grad) + self.grads[i] = None + del cached_grad diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py index 6da8b935..f3d2202c 100644 --- a/src/olmo_core/distributed/fsdp/fsdp.py +++ b/src/olmo_core/distributed/fsdp/fsdp.py @@ -27,9 +27,10 @@ import torch.distributed as dist import torch.nn as nn -from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter -from olmo_core.utils import apply_to_tensors, get_default_device, get_grad_norm +from olmo_core.distributed.tensors import ShardedFlatParameter +from olmo_core.utils import apply_to_tensors, gc_cuda, get_default_device, get_grad_norm +from .flat_param_handle import FlatParamHandle from .state import FSDPState from .stream import Stream @@ -456,10 +457,8 @@ def _managed_named_parameters(self) -> Generator[Tuple[str, nn.Parameter], None, Returns a generator over all parameters managed by this FSDP instance. This is equivalent to `self.module.named_parameters()` except that parameters within nested FSDP instances are omitted. """ - for module_name, module in self._named_children(recurse=lambda m: not isinstance(m, FSDP)): - if not isinstance(module, FSDP): - for param_name, param in module.named_parameters(recurse=False): - yield f"{module_name}.{param_name}", param + for param_name, param in zip(self.state.flat_param_handle.param_fqns, self.state.flat_param_handle.params): + yield param_name, param def _fsdp_children(self, recurse: bool = False) -> Generator[FSDP, None, None]: """ @@ -474,22 +473,36 @@ def _fsdp_children(self, recurse: bool = False) -> Generator[FSDP, None, None]: @torch.no_grad() def _shard(self): """ - Shard the wrapped module in place, replacing each ``nn.Parameter`` with a ``ShardedFlatParameter``. + Shard the wrapped module in place, replacing each ``nn.Parameter`` with a ``ShardedFlatParameter``, + and then collecting all sharded flat param data into a single ``FlatParamHandle``. Afterwards + the sharded data in each sharded flat param will be a view into a single flat tensor managed + by the flat param handle. + This should only be called once at initialization. """ log.debug("Sharding %s...", self.module.__class__.__name__) - for _, m in self._named_children( - recurse=lambda m: not isinstance(m, FSDP) - ): # NOTE: this generator will include `self.module` itself - if isinstance(m, FSDP): + + params: List[ShardedFlatParameter] = [] + param_fqns: List[str] = [] + # NOTE: this generator will include `self.module` itself + for module_name, module in self._named_children(recurse=lambda m: not isinstance(m, FSDP)): + if isinstance(module, FSDP): continue - for param_name, param in m.named_parameters(recurse=False): - # TODO: use better sharding strategy that doesn't potentially always result in highest rank with - # smallest shard. 
+ for param_name, param in module.named_parameters(recurse=False): sharded_flat_param = ShardedFlatParameter.shard( param, process_group=self.process_group, device=self.device, synchronize=False ) - setattr(m, param_name, sharded_flat_param) + setattr(module, param_name, sharded_flat_param) + params.append(sharded_flat_param) + param_fqns.append(f"{module_name}.{param_name}") + + # Collate the data from all flat params into the flat param handle. The data in each flat param + # will then just be a view into a slice of the data managed by the flat param handle. + # This makes unsharding more efficient as we'll only need a single `all_gather` call. + self.state.flat_param_handle = FlatParamHandle.collate_flat_params( + params, param_fqns, process_group=self.process_group, device=self.device + ) + gc_cuda() @torch.no_grad() def _unshard( @@ -515,19 +528,9 @@ def _unshard( # if root to respect the optimizer step and any other computations on the params outside of this # module's forward/backward pass. with self.state.unshard_stream(wait_stream=self.state.current_stream if self.is_root else None): - # TODO: batch the unshards for all params together? - for param_name, param in self._managed_named_parameters(): - if not isinstance(param, ShardedFlatParameter): - continue - - param.unshard_(dtype=self.precision.param_dtype if cast else None, rank0_only=rank0_only) - if cache_grads and param.grad is not None: - # We should only be caching these between the pre-backward and post-backward - # hooks. The post-backward hook will remove the cached grad as it accumulates - # it into persistent sharded grad. - assert param_name not in self.state.sharded_grad_cache - self.state.sharded_grad_cache[param_name] = param.grad.detach() - param.grad = None + self.state.flat_param_handle.unshard_( + self.precision.param_dtype if cast else None, rank0_only=rank0_only, cache_grads=cache_grads + ) if prefetch_from is not None: for module in self._deque_from(prefetch_from): @@ -551,12 +554,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False): self.state.params_prefetched = False with self.state.unshard_stream(wait_stream=self.state.compute_stream): - # TODO: batch the unshards for all params together? - for _, param in self._managed_named_parameters(): - if not isinstance(param, ShardedFlatParameter): - continue - - param.reshard_(writeback=writeback) + self.state.flat_param_handle.reshard_(writeback=writeback) if recurse: for module in self._fsdp_children(): @@ -584,40 +582,10 @@ def _reduce_scatter_grads(self): grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype with self.state.reduce_stream(wait_stream=self.state.current_stream): - # TODO: batch the reductions for all params together? - for param_name, param in self._managed_named_parameters(): - if (unsharded_grad := param.grad) is None: - continue - - log.debug("Reduce-scattering grads for %s.%s...", self.module.__class__.__name__, param_name) - - if grad_reduce_dtype is not None: - unsharded_grad = unsharded_grad.to(dtype=grad_reduce_dtype) - - if not isinstance(param, ShardedFlatParameter): - dist.all_reduce(unsharded_grad, group=self.process_group) - param.grad = unsharded_grad - continue - - if grad_dtype is None: - grad_dtype = param.dtype - - # Only NCCL supports 'reduce_scatter'. So with other backends we use 'all_reduce'. - if dist.get_backend() == dist.Backend.NCCL: - # Get chunks corresponding to each rank. 
- grad_chunks = param.chunk_unsharded(unsharded_grad, pad=True) - new_sharded_grad = torch.empty_like(grad_chunks[0]) - dist.reduce_scatter(new_sharded_grad, grad_chunks, group=self.process_group) - param.grad = new_sharded_grad[: param.unsharded_flattened_offsets[1]].to(dtype=grad_dtype) - else: - dist.all_reduce(unsharded_grad, group=self.process_group) - param.grad = param.sharded_chunk(unsharded_grad).detach().to(dtype=grad_dtype) - - del unsharded_grad - - if (cached_grad := self.state.sharded_grad_cache.pop(param_name, None)) is not None: - param.grad.add_(cached_grad) - del cached_grad + log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__) + self.state.flat_param_handle.reduce_scatter_grads( + grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype + ) def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]: count = 0 diff --git a/src/olmo_core/distributed/fsdp/state.py b/src/olmo_core/distributed/fsdp/state.py index 8966be4c..819339b0 100644 --- a/src/olmo_core/distributed/fsdp/state.py +++ b/src/olmo_core/distributed/fsdp/state.py @@ -9,6 +9,7 @@ from olmo_core.utils import get_default_device +from .flat_param_handle import FlatParamHandle from .stream import Stream if TYPE_CHECKING: @@ -22,6 +23,11 @@ class FSDPState: The device the FSDP node is running on. """ + flat_param_handle: FlatParamHandle = field(default_factory=FlatParamHandle) + """ + Manages the shared data for all sharded flat params. + """ + pre_backward_hook_handles: List[RemovableHandle] = field(default_factory=list) """ Backward hooks registered to the output tensors from the wrapped module's forward method. @@ -33,12 +39,6 @@ class FSDPState: The keys are parameter FQNs. """ - sharded_grad_cache: Dict[str, torch.Tensor] = field(default_factory=dict) - """ - For caching sharded gradients during gradient accumulation. - Maps param FQNs to the corresponding local sharded gradient. - """ - lazy_init_complete: bool = False """ Marked true when final initialization runs lazily during the first forward pass. diff --git a/src/olmo_core/distributed/tensors/__init__.py b/src/olmo_core/distributed/tensors/__init__.py new file mode 100644 index 00000000..2f9c7840 --- /dev/null +++ b/src/olmo_core/distributed/tensors/__init__.py @@ -0,0 +1,8 @@ +""" +Distributed tensor and parameter classes. +""" + +from .sharded_flat_parameter import ShardedFlatParameter +from .sharded_flat_tensor import ShardedFlatTensor, ShardingSpec + +__all__ = ["ShardedFlatTensor", "ShardedFlatParameter", "ShardingSpec"] diff --git a/src/olmo_core/distributed/sharded_flat_parameter.py b/src/olmo_core/distributed/tensors/sharded_flat_parameter.py similarity index 93% rename from src/olmo_core/distributed/sharded_flat_parameter.py rename to src/olmo_core/distributed/tensors/sharded_flat_parameter.py index 73cb0dfb..d599fcf1 100644 --- a/src/olmo_core/distributed/sharded_flat_parameter.py +++ b/src/olmo_core/distributed/tensors/sharded_flat_parameter.py @@ -11,6 +11,10 @@ class ShardedFlatParameter(ShardedFlatTensor, nn.Parameter): + """ + A :class:`~torch.nn.parameter.Parameter` version of :class:`ShardedFlatTensor`. + """ + def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> ShardedFlatParameter: if data is not None and data.ndim != 1: raise ValueError(f"{cls.__name__} requires flat data! 
Got {data.shape}") diff --git a/src/olmo_core/distributed/sharded_flat_tensor.py b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py similarity index 93% rename from src/olmo_core/distributed/sharded_flat_tensor.py rename to src/olmo_core/distributed/tensors/sharded_flat_tensor.py index 48c5731e..d7c979c9 100644 --- a/src/olmo_core/distributed/sharded_flat_tensor.py +++ b/src/olmo_core/distributed/tensors/sharded_flat_tensor.py @@ -3,18 +3,21 @@ import math from dataclasses import dataclass from functools import reduce -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Type, TypeVar import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -from .utils import get_rank, get_world_size +from ..utils import get_rank, get_world_size __all__ = ["ShardedFlatTensor", "ShardingSpec"] +T = TypeVar("T", bound="ShardedFlatTensor") + + @dataclass class ShardingSpec: unsharded_shape: Tuple[int, ...] @@ -125,14 +128,14 @@ def _gather_data(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F @classmethod def shard( - cls, + cls: Type[T], tensor: torch.Tensor, sharding_spec: Optional[ShardingSpec] = None, process_group: Optional[dist.ProcessGroup] = None, synchronize: bool = True, device: Optional[torch.device] = None, requires_grad: Optional[bool] = None, - ) -> ShardedFlatTensor: + ) -> T: """ Shard a tensor across a process group. """ @@ -171,26 +174,35 @@ def shard( else: sharded_tensor = torch.empty(offsets[1] - offsets[0], device=device, dtype=tensor.dtype) - sharded_param = cls( # type: ignore + sharded_tensor = cls( # type: ignore sharded_tensor, requires_grad=requires_grad if requires_grad is not None else tensor.requires_grad ) - sharded_param.mark_as_sharded(sharding_spec, process_group=process_group) - return sharded_param + sharded_tensor.mark_as_sharded(sharding_spec, process_group=process_group) + return sharded_tensor - def gather(self, dtype: Optional[torch.dtype] = None) -> nn.Parameter: + def gather(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False) -> torch.Tensor: """ Gather the sharded flat parameter across a process group into a full unsharded parameter. """ - unsharded_data = self._gather_data(dtype=dtype) - return nn.Parameter(unsharded_data, requires_grad=self.requires_grad) - - def unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False): + unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only) + unsharded_data.requires_grad = self.requires_grad + return unsharded_data + + def unshard_( + self, + unsharded_data: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + rank0_only: bool = False, + ): """ Unshard this parameter's data in-place. You should generally call :meth:`reshard_()` afterwards. If ``rank0_only=True``, non rank 0 processes will have an empty tensor in their data. 
""" - unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only) + if unsharded_data is None: + unsharded_data = self._gather_data(dtype=dtype, rank0_only=rank0_only) + elif not rank0_only or get_rank(self.process_group) == 0: + unsharded_data = unsharded_data.view(*self.unsharded_shape) self._set_metadata(self.SHARDED_FLAT_TENSOR_CACHED_SHARDED_DATA_KEY, self.data) self.data = unsharded_data diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index 0a85cbaa..71c92a60 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -1,4 +1,5 @@ import dataclasses +import gc import os import time from enum import Enum @@ -163,3 +164,21 @@ def get_grad_norm(params: Iterable[nn.Parameter], norm_type: float) -> torch.Ten dtype=torch.float32, ) return grad_norm + + +def same_storage(x: torch.Tensor, y: torch.Tensor) -> bool: + """ + Check if two tensors share the same storage. + """ + x_ptrs = set(e.data_ptr() for e in x.view(-1)) + y_ptrs = set(e.data_ptr() for e in y.view(-1)) + return (x_ptrs <= y_ptrs) or (y_ptrs <= x_ptrs) + + +def gc_cuda(): + """ + Run CUDA garbage collection. + """ + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 866e0093..400fc6a1 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -20,8 +20,11 @@ unshard_optim_state, ) from olmo_core.distributed.fsdp import FSDP -from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter -from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec +from olmo_core.distributed.tensors import ( + ShardedFlatParameter, + ShardedFlatTensor, + ShardingSpec, +) from .utils import ( BACKENDS, diff --git a/src/test/distributed/fsdp/flat_param_handle_test.py b/src/test/distributed/fsdp/flat_param_handle_test.py new file mode 100644 index 00000000..8f43dd71 --- /dev/null +++ b/src/test/distributed/fsdp/flat_param_handle_test.py @@ -0,0 +1,54 @@ +import pytest +import torch +import torch.distributed as dist + +from olmo_core.distributed.fsdp.flat_param_handle import FlatParamHandle +from olmo_core.distributed.tensors import ShardedFlatParameter +from olmo_core.utils import same_storage + +from ..utils import BACKENDS, get_default_device, run_distributed_test + + +def run_flat_param_handle_collate_flat_params(): + all_og_data = [ + torch.rand(2, 3, device=get_default_device()), + torch.rand(4, 8, device=get_default_device()), + torch.rand(7, device=get_default_device()), + ] + for og_data in all_og_data: + dist.all_reduce(og_data) + + flat_params = [ShardedFlatParameter.shard(og_data) for og_data in all_og_data] + handle = FlatParamHandle.collate_flat_params(flat_params, ["x", "y", "z"], device=get_default_device()) + for param in handle.params: + assert same_storage(param, handle.params_data) + + # Unshard all params. + handle.unshard_() + for og_data, param in zip(all_og_data, handle.params): + assert not param.is_sharded + torch.testing.assert_close(param.data, og_data) + + # Reshard all params. + handle.reshard_() + for param in handle.params: + assert param.is_sharded + assert same_storage(param, handle.params_data) + + # Updated the data in a param should update the data in the handle, since the data in the + # param is just a view into the data in the handle. 
+ handle.params[0].fill_(torch.tensor(0.0, device=get_default_device())) + assert (handle.params_data[0 : handle.params[0].numel()] == 0).all() + handle.unshard_() + assert (handle.params[0] == 0).all() + handle.params[0].fill_(torch.tensor(1.0, device=get_default_device())) + handle.reshard_(writeback=True) + assert (handle.params[0] == 1).all() + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_flat_param_handle_collate_flat_params(backend): + run_distributed_test( + run_flat_param_handle_collate_flat_params, + backend=backend, + ) diff --git a/src/test/distributed/fsdp/fsdp_test.py b/src/test/distributed/fsdp/fsdp_test.py index 07a3f70b..9006f870 100644 --- a/src/test/distributed/fsdp/fsdp_test.py +++ b/src/test/distributed/fsdp/fsdp_test.py @@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from olmo_core.distributed.fsdp import FSDP, FSDPDebugConfig -from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter +from olmo_core.distributed.tensors import ShardedFlatParameter from ..utils import ( BACKENDS, @@ -43,6 +43,11 @@ def run_fsdp_against_non_distributed_model(model_factory, model_data_factory): with torch.no_grad(): torch.testing.assert_close(param.data, param.sharded_chunk(model.state_dict()[name])) + with fsdp.summon_full_params(): + for name, param in fsdp.module.named_parameters(): + with torch.no_grad(): + torch.testing.assert_close(param.data, model.state_dict()[name]) + # Run forward/backward pass on non-distributed model and collect grads for comparison. expected_grads = {} loss = model(model_data).sum() @@ -154,7 +159,8 @@ def run_fsdp_against_ddp(model_factory, model_data_factory): # Since we've only done a single backwards pass (no grad accumulation), there shouldn't # be any cached gradients. - assert not fsdp_model.state.sharded_grad_cache + for cached_grad in fsdp_model.state.flat_param_handle.grads: + assert cached_grad is None # Run optimizer step. 
optim.step() @@ -265,6 +271,20 @@ def forward(self, x): "inner.fc.4.bias", }, param_names + assert set(fsdp.state.flat_param_handle.param_fqns) == { + "out.weight", + "out.bias", + } + + assert set(fsdp.module.inner.state.flat_param_handle.param_fqns) == { + "fc.0.weight", + "fc.0.bias", + "fc.2.weight", + "fc.2.bias", + "fc.4.weight", + "fc.4.bias", + } + buf_names = set(n for n, _ in fsdp.named_buffers()) assert buf_names == {"buf"} diff --git a/src/test/distributed/tensors/__init__.py b/src/test/distributed/tensors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/test/distributed/sharded_flat_parameter_test.py b/src/test/distributed/tensors/sharded_flat_parameter_test.py similarity index 94% rename from src/test/distributed/sharded_flat_parameter_test.py rename to src/test/distributed/tensors/sharded_flat_parameter_test.py index 5b0d06c4..e768ee6b 100644 --- a/src/test/distributed/sharded_flat_parameter_test.py +++ b/src/test/distributed/tensors/sharded_flat_parameter_test.py @@ -4,10 +4,13 @@ import torch import torch.distributed as dist -from olmo_core.distributed.sharded_flat_parameter import ShardedFlatParameter -from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec +from olmo_core.distributed.tensors.sharded_flat_parameter import ShardedFlatParameter +from olmo_core.distributed.tensors.sharded_flat_tensor import ( + ShardedFlatTensor, + ShardingSpec, +) -from .utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test +from ..utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test def test_init_empty_sharded_parameter(): diff --git a/src/test/distributed/sharded_flat_tensor_test.py b/src/test/distributed/tensors/sharded_flat_tensor_test.py similarity index 96% rename from src/test/distributed/sharded_flat_tensor_test.py rename to src/test/distributed/tensors/sharded_flat_tensor_test.py index 40eb82a2..65ae3bbd 100644 --- a/src/test/distributed/sharded_flat_tensor_test.py +++ b/src/test/distributed/tensors/sharded_flat_tensor_test.py @@ -4,9 +4,12 @@ import torch import torch.distributed as dist -from olmo_core.distributed.sharded_flat_tensor import ShardedFlatTensor, ShardingSpec +from olmo_core.distributed.tensors.sharded_flat_tensor import ( + ShardedFlatTensor, + ShardingSpec, +) -from .utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test +from ..utils import BACKENDS, INIT_DEVICES, get_default_device, run_distributed_test def test_init_sharded():