From c8f1385b70d6490739e59dd7c62c95a0b75e7973 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 30 Jan 2024 15:51:31 -0500
Subject: [PATCH] Torch 2.2 Support (#2930)

* lint

* rename

* lint

* type ignore

* allow empty
---
 composer/trainer/mosaic_fsdp.py              | 19 +++++++++-----
 composer/trainer/mosaic_fsdp_utils.py        | 26 +++++++++----------
 composer/trainer/trainer.py                  |  2 +-
 composer/utils/fx_utils.py                   |  1 +
 .../test_in_context_learning_datasets.py     |  8 +++---
 5 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 1b346e92e4..7b9dc83e59 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -12,9 +12,6 @@
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
 from torch.distributed.fsdp import FullyShardedDataParallel
 
-from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
-                                                custom_auto_wrap_t1p13p1)
-
 
 def patch_pytorch():
     """Monkey patches pytorch functions based on pytorch version."""
@@ -25,6 +22,7 @@ def patch_pytorch():
         # Monkey patch for torch < 2.0 ie torch == 1.13.1
         # Monkey patch _auto_wrap with _custom_auto_wrap fn
+        from composer.trainer.mosaic_fsdp_utils import custom_auto_wrap_t1p13p1
         FullyShardedDataParallel._auto_wrap = custom_auto_wrap_t1p13p1  # type: ignore
 
     elif version.parse(torch.__version__) < version.parse('2.0.1'):
@@ -39,16 +37,23 @@ def patch_pytorch():
         FullyShardedDataParallel.__init__ = init_fn_t2p0p1  # type: ignore
 
         # Monkey patch sharding method
+        from composer.trainer.mosaic_fsdp_utils import build_metadata
+
         ChunkShardingSpec.build_metadata = build_metadata
 
     elif version.parse(torch.__version__) < version.parse('2.1.1'):
         # Monkey patch for torch < 2.1.1 ie torch == 2.1.0
         # Monkey patch sharding method
+        from composer.trainer.mosaic_fsdp_utils import build_metadata
+
         ChunkShardingSpec.build_metadata = build_metadata
 
         # Monkey patch partial state dict handling
         from torch.distributed.fsdp import _state_dict_utils
+
+        from composer.trainer.mosaic_fsdp_utils import _sharded_pre_load_state_dict_hook
+
         _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)
 
         # Allow 2D HSDP
@@ -79,14 +84,14 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
         # Monkeypath state_dict
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
-        FullyShardedDataParallel.__init__ = init_fn_t2p2p0
+        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
+        FullyShardedDataParallel.__init__ = init_fn_t2p3p0
 
         # Monkeypath state_dict
         from torch.distributed.checkpoint import state_dict  # type: ignore
-        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
-        state_dict._verify_options = _verify_options_t2p2p0
+        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
+        state_dict._verify_options = _verify_options_t2p3p0
 
         # Monkeypatch sharding optim state
         from torch.distributed.fsdp import _optim_utils
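For context: every branch of patch_pytorch() above follows the same pattern, namely compare the installed torch release with packaging.version and monkey patch only inside the matching branch, importing each replacement lazily so a symbol needed for one torch version is never imported on another. The sketch below shows that dispatch pattern in isolation; patch_example and _wrapped_init are hypothetical names used only for this illustration, not part of Composer.

from packaging import version

import torch


def patch_example():
    """Version-gated monkey patching, mirroring the structure of patch_pytorch()."""
    v = version.parse(torch.__version__)

    if v < version.parse('2.0.0'):
        # torch 1.13.x branch: any 1.13-only replacement would be imported lazily
        # here so it is never evaluated on newer torch releases.
        pass

    elif version.parse('2.2.9') < v < version.parse('2.3.1'):
        # torch 2.3.0 branch: swap in a wrapper, the same way the hunks above
        # replace FullyShardedDataParallel.__init__ with init_fn_t2p3p0.
        from torch.distributed.fsdp import FullyShardedDataParallel

        original_init = FullyShardedDataParallel.__init__

        def _wrapped_init(self, *args, **kwargs):
            # A real patch would change behavior; this sketch only delegates.
            original_init(self, *args, **kwargs)

        FullyShardedDataParallel.__init__ = _wrapped_init  # type: ignore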
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index af518d8a4b..448a04f80a 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -762,7 +762,7 @@ def _sharded_pre_load_state_dict_hook(
         _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
 
 
-if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse(
+if version.parse(torch.__version__) > version.parse('2.2.9') and version.parse(
         torch.__version__) < version.parse('2.3.1'):
     import copy
@@ -787,7 +787,7 @@ def _sharded_pre_load_state_dict_hook(
     from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy
     from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
 
-    def all_gather_dtensor_t2p2p0(
+    def all_gather_dtensor_t2p3p0(
         self,
         tensor: DTensor,
         parent_mesh: Optional[DeviceMesh],
@@ -806,7 +806,7 @@ def all_gather_dtensor_t2p2p0(
         )
         return tensor.to_local()
 
-    def chunk_dtensor_t2p2p0(
+    def chunk_dtensor_t2p3p0(
         self,
         tensor: torch.Tensor,
         rank: int,
@@ -869,10 +869,10 @@ def chunk_dtensor_t2p2p0(
             placements=shard_placements,
         )
 
-    DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0
-    DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0
+    DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p3p0
+    DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p3p0
 
-    def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
+    def _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh: DeviceMesh) -> bool:
         #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
         #if parent_mesh is not None:
         #    raise RuntimeError(
@@ -881,13 +881,13 @@ def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
         #    )
         return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2
 
-    def _init_process_group_state_for_hybrid_shard_t2p2p0(
+    def _init_process_group_state_for_hybrid_shard_t2p3p0(
         state: _FSDPState,
         process_group: ProcessGroupType,
         device_mesh: DeviceMesh,
     ) -> _FSDPState:
         if device_mesh:
-            if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh):
+            if _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh):
                 state._device_mesh = device_mesh
                 # We currently only allow _inter_node_pg to be the outermost dimension, and the
                 # process_group(intra_node) to be the innermost dimension.
@@ -917,7 +917,7 @@ def _init_process_group_state_for_hybrid_shard_t2p2p0(
         state._inter_node_state = _get_default_comm_hook_state(process_group=state._inter_node_pg,)
         return state
 
-    def _init_process_group_state_t2p2p0(
+    def _init_process_group_state_t2p3p0(
         state: _FSDPState,
         process_group: ProcessGroupType,
         sharding_strategy: ShardingStrategy,
@@ -938,7 +938,7 @@ def _init_process_group_state_t2p2p0(
                     'requires explicit specification of process group or device_mesh.',
                 )
             else:
-                state = _init_process_group_state_for_hybrid_shard_t2p2p0(state, process_group, device_mesh)
+                state = _init_process_group_state_for_hybrid_shard_t2p3p0(state, process_group, device_mesh)
         else:
             if device_mesh:
                 state._device_mesh = device_mesh
@@ -956,7 +956,7 @@ def _init_process_group_state_t2p2p0(
             state._gradient_postdivide_factor = (data_parallel_world_size / state._gradient_predivide_factor)
         return state
 
-    def init_fn_t2p2p0(
+    def init_fn_t2p3p0(
         self,
         module: nn.Module,
         process_group: ProcessGroupType = None,
@@ -990,7 +990,7 @@ def init_fn_t2p2p0(
         # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
         # the same process group state as the root FSDP module.
         self._device_mesh = device_mesh
-        _init_process_group_state_t2p2p0(
+        _init_process_group_state_t2p3p0(
             self,
             process_group,
             sharding_strategy,
@@ -1064,7 +1064,7 @@ def init_fn_t2p2p0(
 
     from torch.distributed.checkpoint.state_dict import StateDictOptions, _StateDictInfo
 
-    def _verify_options_t2p2p0(
+    def _verify_options_t2p3p0(
         model: nn.Module,
         optims: Tuple[torch.optim.Optimizer, ...],
         optim_only: bool,
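For context: the renames above (the _t2p2p0 helpers becoming _t2p3p0) do not change what these helpers do within the hunks shown; the functional core for hybrid sharding is still _is_valid_hybrid_shard_device_mesh_t2p3p0, which accepts any 2D DeviceMesh whose outer dimension is the inter-node (replication) axis and whose inner dimension is the intra-node (sharding) axis. The sketch below shows the kind of mesh that satisfies that check, assuming torch >= 2.2 (where torch.distributed.device_mesh.init_device_mesh is available) and an already-initialized distributed process group; build_hsdp_mesh is a hypothetical helper, not part of this patch.

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def build_hsdp_mesh(num_nodes: int, gpus_per_node: int) -> DeviceMesh:
    """Build the 2D (replicate, shard) mesh that the patched validity check expects."""
    # Outer dim: replication across nodes; inner dim: sharding within a node.
    mesh = init_device_mesh('cuda', (num_nodes, gpus_per_node), mesh_dim_names=('replicate', 'shard'))
    # This mirrors _is_valid_hybrid_shard_device_mesh_t2p3p0: a DeviceMesh with ndim == 2 passes.
    assert isinstance(mesh, DeviceMesh) and mesh.ndim == 2
    return mesh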
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 80a519d758..b144228f42 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2531,7 +2531,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
                 microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_num_samples / current_batch_size)
 
             if use_grad_scaling:
-                microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss))
+                microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss))  # type: ignore
 
             if self.state.deepspeed_enabled:
                 self.state.deepspeed_model.backward(microbatch_loss)
diff --git a/composer/utils/fx_utils.py b/composer/utils/fx_utils.py
index 2b1ff41b3e..9162b84878 100644
--- a/composer/utils/fx_utils.py
+++ b/composer/utils/fx_utils.py
@@ -234,6 +234,7 @@ def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[
             f'Input to apply_stochastic_residual should be an instance of GraphModule. Received {type(gm)}')
     all_tags, count = _tag_residual_nodes(gm)
     split_gm = split_by_tags(gm, all_tags)
+    assert isinstance(split_gm, GraphModule)
     for node in split_gm.graph.nodes:
         if node.op != 'call_module':
             continue
diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index de2114dd76..0373a832b8 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -220,13 +220,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1
     input_ids = torch.LongTensor([seq1, seq2])
-    assert not eos_criteria(input_ids, None)
+    assert not eos_criteria(input_ids, None)  # type: ignore
 
     eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
     seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     input_ids = torch.LongTensor([seq1, seq2])
-    assert eos_criteria(input_ids, None)
+    assert eos_criteria(input_ids, None)  # type: ignore
 
 
 def test_stop_sequences_criteria_sentencepiece(tiny_llama_tokenizer):
@@ -238,13 +238,13 @@ def test_stop_sequences_criteria_sentencepiece(tiny_llama_tokenizer):
     seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
     seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1
     input_ids = torch.LongTensor([seq1, seq2])
-    assert not eos_criteria(input_ids, None)
+    assert not eos_criteria(input_ids, None)  # type: ignore
 
     eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2)
     seq1 = tokenizer('Dogs are furry\n\n')['input_ids']
     seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
     input_ids = torch.LongTensor([seq1, seq2])
-    assert eos_criteria(input_ids, None)
+    assert eos_criteria(input_ids, None)  # type: ignore
 
 
 @pytest.mark.filterwarnings(
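For context: the trainer.py hunk above only adds a type: ignore to an existing cast; GradScaler.scale() is annotated to return either a single Tensor or an iterable of Tensors, so callers that need a plain Tensor have to narrow the type themselves. The sketch below shows that underlying AMP pattern in isolation, assuming a CUDA build of torch; scaled_backward is a hypothetical helper, not Composer API.

from typing import Optional

import torch


def scaled_backward(loss: torch.Tensor, scaler: Optional[torch.cuda.amp.GradScaler]) -> None:
    """Scale the loss before backward only when a grad scaler is in use."""
    if scaler is not None:
        # scale() is typed as Tensor-or-iterable, so assigning its result back to a
        # plain Tensor variable is what forces the cast / type: ignore in the trainer.
        loss = scaler.scale(loss)  # type: ignore[assignment]
    loss.backward()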