Commit

Fix import for daily test (mosaicml#2826)
* patched torch

* fixed torch imports

* fixed torch imports

* fixed torch imports

* patching through composer

* patching through composer

* patching typing

* comment added

* don't patch torch 2.1.0

* patch torch 2.1.1 and 2.2.0

* linting fix

* waiting on computation stream from unshard stream

* waiting on computation stream from unshard stream

* less waiting

* no waiting

* all unshard streams wait on computation stream now

* 2.2.0 dev change

* correct waiting on computation stream

* fsdp state typing

* patching root pre forward

* patching root pre forward

* fsdp state typing

* patch forward

* correct waiting

* linting

* daily test change

* daily test fix
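
Several of the commits above ("waiting on computation stream from unshard stream", "all unshard streams wait on computation stream now") concern CUDA stream ordering between FSDP's unshard (all-gather) stream and the default computation stream. The sketch below only illustrates the general stream-synchronization idea with torch.cuda.Stream.wait_stream; it is not the patched FSDP code itself, and the tensor work is a stand-in.

# Illustrative sketch only: the general pattern of making a side
# ("unshard"-style) stream wait on the default computation stream, and
# vice versa, using torch.cuda.Stream.wait_stream. Workload is a stand-in.
import torch

if torch.cuda.is_available():
    computation_stream = torch.cuda.current_stream()
    unshard_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x  # work enqueued on the computation (default) stream

    # Work later submitted to the unshard stream will not start until
    # everything already enqueued on the computation stream has finished.
    unshard_stream.wait_stream(computation_stream)

    with torch.cuda.stream(unshard_stream):
        gathered = y * 2  # stand-in for communication (e.g. an all-gather)

    # Before the computation stream consumes the result, wait the other way.
    computation_stream.wait_stream(unshard_stream)
    z = gathered + 1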
snarayan21 authored Jan 8, 2024
1 parent 23bc6fb commit a36fb74
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions composer/trainer/mosaic_fsdp.py
@@ -9,6 +9,7 @@
 import torch
 from packaging import version
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed.fsdp import FullyShardedDataParallel
 
 from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
                                                 custom_auto_wrap_t1p13p1)
@@ -61,8 +62,6 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
         # Better overlap communication and computation
-        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-
         from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
                                                         _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
@@ -75,7 +74,6 @@ def patch_pytorch():
 
         # Better overlap communication and computation
         from torch.distributed.fsdp import _runtime_utils
-        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
         from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
                                                         _wait_for_computation_stream, forward)
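
For context, patch_pytorch() in this file applies monkey patches only on matching torch versions (per the commit notes: skip 2.1.0, patch 2.1.1 and 2.2.0). Below is a minimal sketch of that version-gating pattern; the branch boundaries are simplified and only the overlap-related patch visible in this diff is shown, so it is not the full implementation.

# Minimal sketch (not the full implementation) of the version-gated
# monkey-patching pattern in composer/trainer/mosaic_fsdp.py. Branch
# boundaries are simplified; only the overlap-related patch is shown.
import torch
from packaging import version


def patch_pytorch_sketch():
    v = version.parse(torch.__version__)
    if v < version.parse('2.1.1'):
        # Per the commit notes, this overlap patch is not applied to
        # torch 2.1.0 and earlier (other patches for older versions omitted).
        return
    from torch.distributed.fsdp import _runtime_utils
    if v < version.parse('2.2.0'):
        # torch 2.1.1 / 2.1.2: install the 2.1-specific patched helper.
        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
    else:
        # torch 2.2.0 dev: install the 2.2-specific patched helper.
        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2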
