Fix a CUDA synchronization issue with FSDP (#19)
* record compute stream with unsharded data

* add explanation comment
epwalsh authored May 15, 2024
1 parent fb20119 · commit 7a2299a
Showing 2 changed files with 8 additions and 1 deletion.
src/olmo_core/distributed/fsdp/flat_param_handle.py (1 addition, 1 deletion)
@@ -338,7 +338,7 @@ def unshard_(
             self.params_sharded_data_lp = None
             del local_shard

-        # Set the data for each param as a view into `all_params_unsharded_data`.
+        # Set the data for each param as a view into `self.params_data`.
         offset = 0
         for param in self.params:
             if rank0_only and local_rank != 0:
src/olmo_core/distributed/fsdp/fsdp.py (7 additions, 0 deletions)
@@ -687,6 +687,13 @@ def _unshard(
                 set_grads=set_grads,
             )

+        # Record the current stream for the unsharded parameter data to make sure it's not
+        # deallocated prematurely.
+        for handle in self.state.flat_param_handles:
+            self.state.current_stream.record_for(handle.params_data)
+            if handle.params_unsharded_grad is not None:
+                self.state.current_stream.record_for(handle.params_unsharded_grad)
+
         if recurse:
             for module in self._fsdp_children():
                 module._unshard(**kwargs)
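The added block guards against a use-after-free hazard with PyTorch's caching allocator: if the unsharded parameter data is freed while another CUDA stream still has kernels queued against it, the allocator may hand that memory back out too early. Below is a minimal sketch of the underlying pattern using plain PyTorch stream APIs (torch.cuda.Stream, Tensor.record_stream); the names unshard_stream and compute_stream are illustrative, and olmo_core's current_stream.record_for is assumed to wrap this same mechanism.

# Minimal sketch of the stream-recording pattern, using plain PyTorch APIs
# rather than olmo_core's Stream wrapper. Assumes a CUDA-capable machine.
import torch

if torch.cuda.is_available():
    unshard_stream = torch.cuda.Stream()   # side stream doing the all-gather / unshard
    compute_stream = torch.cuda.current_stream()

    with torch.cuda.stream(unshard_stream):
        # Stand-in for the unsharded parameter data produced on the side stream.
        params_data = torch.empty(1024, device="cuda").normal_()

    # Compute must not start reading the data before the unshard work finishes.
    compute_stream.wait_stream(unshard_stream)

    # Mark the tensor as in use by the compute stream so the caching allocator
    # will not recycle its memory until the compute stream's queued work is done,
    # even if the Python reference is dropped (e.g. when params are resharded).
    params_data.record_stream(compute_stream)

In the commit itself, the same recording is applied to both handle.params_data and handle.params_unsharded_grad immediately after the unshard call.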
