docs
epwalsh committed Apr 17, 2024
1 parent 4ed1cb4 commit 71b1aee
Showing 1 changed file with 4 additions and 1 deletion.
src/olmo_core/distributed/fsdp/flat_param_handle.py (4 additions & 1 deletion)
```diff
@@ -228,7 +228,8 @@ def shard_params(

     def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = False, set_grads: bool = False):
         """
-        Allocate the unsharded, padded data prior to the all-gather.
+        Allocate the unsharded, padded data prior to the all-gather. Ideally this should be called
+        in a separate stream from :meth:`.unshard_()` for better throughput.
         """
         self._ran_pre_unshard = True

@@ -272,6 +273,8 @@ def unshard_(
         if not self._ran_pre_unshard:
             self.pre_unshard_(dtype=dtype, rank0_only=rank0_only, set_grads=set_grads)
         else:
+            # The following tensors were potentially created in a different stream, so we need
+            # to make sure they're not deallocated prematurely.
             Stream.current(self.device).record_for(self.params_data.data)
             if self.params_sharded_data_lp is not None:
                 Stream.current(self.device).record_for(self.params_sharded_data_lp)
```
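For context, here is a minimal sketch of how the two changes fit together, assuming standard PyTorch CUDA streams; the `handle` variable and the training-loop framing are hypothetical, not part of this diff. The docstring recommends running `pre_unshard_` in a side stream so buffer allocation overlaps with other work, and the new comments explain why `Stream.record_for` (which appears to play the role of `Tensor.record_stream`) must mark cross-stream tensors so the caching allocator does not reclaim them while another stream still uses them.

```python
import torch

# Hypothetical usage sketch: `handle` stands in for a FlatParamHandle
# instance; only the pre_unshard_/unshard_ calls come from the diff above.

alloc_stream = torch.cuda.Stream()

# Allocate the unsharded, padded buffers off the default stream, as the
# updated docstring suggests, so allocation (and any low-precision cast)
# overlaps with compute on the default stream.
with torch.cuda.stream(alloc_stream):
    handle.pre_unshard_(dtype=torch.bfloat16, rank0_only=False, set_grads=False)

# The default stream must wait for those allocations before the all-gather
# reads them.
torch.cuda.current_stream().wait_stream(alloc_stream)

# Inside unshard_(), tensors created in `alloc_stream` are recorded on the
# current stream (the record_for calls in the diff), so PyTorch's caching
# allocator won't free or reuse their memory until this stream's pending
# work on them completes.
handle.unshard_(dtype=torch.bfloat16)
```

Without that recording step, a tensor allocated in one stream and freed on another can have its memory handed back to the allocator while the all-gather is still in flight, which is exactly the premature deallocation the new comments warn about.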
