Skip to content

Commit

Permalink
Remove monkeypatch and new state dict APIs for torch 2.2 (mosaicml#2899)
Browse files Browse the repository at this point in the history
* fix mosaicfsdp

* bump to 2.3

* remove init
  • Loading branch information
mvpatel2000 authored Jan 24, 2024
1 parent cfc439a commit 704c07e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 22 deletions.
8 changes: 4 additions & 4 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
Expand Down Expand Up @@ -909,7 +909,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
set_model_state_dict(
model=self.model,
Expand Down Expand Up @@ -1277,7 +1277,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
state dict should perfectly match the keys in the optimizer instance.
"""
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
optimizer = self.optimizers[0]
set_optimizer_state_dict(
Expand Down
10 changes: 0 additions & 10 deletions composer/trainer/mosaic_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,6 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Monkeypatch dtensor support
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore

# Monkeypath state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
state_dict._verify_options = _verify_options_t2p2p0

elif version.parse(torch.__version__) < version.parse('2.3.1'):
# Monkey patch for torch < 2.3.1 ie torch == 2.3.0
# Note: this is the same patch as 2.2.0, we are just making a new if branch
Expand Down
5 changes: 4 additions & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,10 @@ def __init__(
assert not isinstance(device_train_microbatch_size, str)

# Distributed
dist.initialize_dist(device, dist_timeout)
if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1:
# Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1
# And torch.distributed is always required for multi-rank training
dist.initialize_dist(device, dist_timeout)

# Reproducibility
rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device)
Expand Down
14 changes: 7 additions & 7 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
else:
cur_state_dict = state.state_dict()
# For older versions of torch, we load optimizer separately.
if version.parse(torch.__version__) < version.parse('2.1.3'):
if version.parse(torch.__version__) < version.parse('2.2.9'):
cur_state_dict.pop('optimizers')
state_dict: Dict[str, Any] = {
'state': cur_state_dict,
Expand Down Expand Up @@ -523,7 +523,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
else:
expect_file = True

if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
dist_cp.load( # type: ignore
state_dict=state_dict,
storage_reader=storage_reader,
Expand All @@ -547,8 +547,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
)

# 2. Optionally load optimizer
# if we are using later than 2.1.0 then optimizer will already be loaded
if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only:
# if we are using later than 2.2.9 then optimizer will already be loaded
if version.parse(torch.__version__) < version.parse('2.2.9') and not load_weights_only:
optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'],
optimizer_key='optimizers',
storage_reader=storage_reader)
Expand Down Expand Up @@ -956,12 +956,12 @@ def _save_checkpoint(
state_dict['state'] = state_dict.get('state', {})

if state.fsdp_sharded_state_dict_enabled:
# To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top
# To load optimizer states with 2.0 <= torch < 2.2.9 , the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
# requires a top level state dict key for the optimizer.
# See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271
# for more info.
if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'):
if using_torch_2() and version.parse(torch.__version__) < version.parse('2.2.9'):
if not weights_only:
state_dict['optimizers'] = state_dict['state'].pop('optimizers')
log.debug('State dict created.')
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def _save_checkpoint(
expect_file = True

if expect_file:
if version.parse(torch.__version__) > version.parse('2.1.3'):
if version.parse(torch.__version__) > version.parse('2.2.9'):
dist_cp.save( # type: ignore
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(dirname),
Expand Down

0 comments on commit 704c07e

Please sign in to comment.