Use FSDP CustomPolicy to support custom kwargs passed to different wrapped modules (#2585)

* wip

* formatting

* cleanup

* add torch version check

* remove patching torch 2.1.1

* refactor

* adding back monkey patching ChunkShardingSpec

* remove 2.1 patch util functions

* add type ignore

---------

Co-authored-by: Mihir Patel <[email protected]>
cli99 and mvpatel2000 authored Oct 3, 2023
1 parent 67c7819 commit e2386b3
Showing 3 changed files with 49 additions and 347 deletions.
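
Before the diffs, a minimal sketch of the torch 2.1 API this commit adopts, under the assumption that `CustomPolicy` takes a callable returning either a bool or a dict of per-module FSDP kwarg overrides; the module type and the sharding override below are illustrative, not taken from this repository.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy


def _wrap_decision(module: nn.Module):
    # True -> wrap with the root FSDP kwargs; dict -> wrap with these overrides;
    # False -> do not wrap this submodule directly.
    if isinstance(module, nn.TransformerEncoderLayer):
        return {'sharding_strategy': ShardingStrategy.SHARD_GRAD_OP}
    return False


def build_fsdp_model(model: nn.Module) -> FSDP:
    # Hypothetical usage; requires an initialized process group and a real model.
    return FSDP(model, auto_wrap_policy=CustomPolicy(_wrap_decision))
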
76 changes: 49 additions & 27 deletions composer/trainer/dist_strategy.py
@@ -25,6 +25,8 @@
 
 log = logging.getLogger(__name__)
 
+process_group_cache = {}
+
 
 class DDPSyncStrategy(StringEnum):
     """How and when gradient synchronization should happen.
@@ -437,36 +439,56 @@ def _param_init_fn(module: torch.nn.Module) -> None:
                     'This leaves parameters without initialization. Please add a ``param_init_fn`` or ``reset_parameters`` '
                     f'to module `{obj_name}`.')
 
-        # Choose which modules to FSDP wrap according to the following priority:
-        # If module has attribute `module._fsdp_wrap = ...`, always respect it
-        # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
-        def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
-            if recurse:
-                return True
-            should_be_wrapped = False
-            if hasattr(module, '_fsdp_wrap'):
-                should_be_wrapped = bool(module._fsdp_wrap)
-            elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
-                should_be_wrapped = obj.fsdp_wrap_fn(module)
-
-            if should_be_wrapped and auto_microbatching:
-                module.register_forward_hook(sync_hook)
-                module.register_full_backward_hook(sync_hook)
-            return should_be_wrapped
-
-        if is_torch_2_0:
-
-            def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
-                return __auto_wrap_policy(module, recurse, nonwrapped_numel)
-
-            _auto_wrap_policy = _auto_wrap_policy_new
-
-        else:
-
-            def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
-                return __auto_wrap_policy(module, recurse, unwrapped_params)
-
-            _auto_wrap_policy = _auto_wrap_policy_old
+        if version.parse(torch.__version__) > version.parse('2.1.0.dev'):
+            # CustomPolicy is only supported in torch v2.1.0-rc1 or higher
+            from torch.distributed.fsdp.wrap import CustomPolicy  # type: ignore
+
+            def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
+                ret = False
+                if hasattr(module, '_fsdp_wrap'):
+                    ret = bool(module._fsdp_wrap)
+                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
+                    ret = obj.fsdp_wrap_fn(module)
+                    from composer.trainer.mosaic_fsdp_utils import _set_custom_fsdp_module_kwargs
+                    if isinstance(ret, dict):
+                        ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache)
+                if ret and auto_microbatching:
+                    module.register_forward_hook(sync_hook)
+                    module.register_full_backward_hook(sync_hook)
+                return ret
+
+            _auto_wrap_policy = CustomPolicy(lambda_fn)
+        else:
+            # Choose which modules to FSDP wrap according to the following priority:
+            # If module has attribute `module._fsdp_wrap = ...`, always respect it
+            # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
+            def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+                if recurse:
+                    return True
+                should_be_wrapped = False
+                if hasattr(module, '_fsdp_wrap'):
+                    should_be_wrapped = bool(module._fsdp_wrap)
+                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
+                    should_be_wrapped = obj.fsdp_wrap_fn(module)
+
+                if should_be_wrapped and auto_microbatching:
+                    module.register_forward_hook(sync_hook)
+                    module.register_full_backward_hook(sync_hook)
+                return should_be_wrapped
+
+            if is_torch_2_0:
+
+                def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+                    return __auto_wrap_policy(module, recurse, nonwrapped_numel)
+
+                _auto_wrap_policy = _auto_wrap_policy_new
+
+            else:
+
+                def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
+                    return __auto_wrap_policy(module, recurse, unwrapped_params)
+
+                _auto_wrap_policy = _auto_wrap_policy_old
 
         fsdp_obj = FullyShardedDataParallel(
             obj,
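
The `lambda_fn` above consumes a model-level `fsdp_wrap_fn` hook. A minimal sketch of how a model might use the new dict return path, assuming string-valued kwargs such as 'sharding_strategy' are among those `_set_custom_fsdp_module_kwargs` resolves; the key and the model itself are illustrative.

import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model exercising the per-module kwargs path added above."""

    def __init__(self, n_layers: int = 4, d_model: int = 128):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_layers))

    def fsdp_wrap_fn(self, module: nn.Module):
        # bool -> wrap (or not) with the default kwargs; dict -> wrap this
        # module with the returned kwargs, resolved by _set_custom_fsdp_module_kwargs.
        if isinstance(module, nn.Linear):
            return {'sharding_strategy': 'SHARD_GRAD_OP'}
        return False
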
4 changes: 0 additions & 4 deletions composer/trainer/mosaic_fsdp.py
@@ -42,10 +42,6 @@ def patch_pytorch():
     elif version.parse(torch.__version__) < version.parse('2.1.1'):
         # Monkey path for torch < 2.1.1 ie torch == 2.1.0
 
-        # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p1p0
-        FullyShardedDataParallel.__init__ = init_fn_t2p1p0  # type: ignore
-
         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard
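
Both hunks gate behavior on `version.parse` comparisons. A short illustration of why the dist_strategy.py check compares against '2.1.0.dev' (so 2.1.0 release candidates also take the CustomPolicy path), using packaging's PEP 440 ordering.

from packaging import version

# Dev releases sort before release candidates, which sort before the final
# release of the same version, so '> 2.1.0.dev' admits 2.1.0rc1 and later.
assert version.parse('2.1.0.dev0') < version.parse('2.1.0rc1') < version.parse('2.1.0')
assert version.parse('2.1.0rc1') > version.parse('2.1.0.dev')
assert version.parse('2.0.1') < version.parse('2.1.0.dev')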