Torch 2.2 Support (mosaicml#2930)
* lint

* rename

* lint

* type ignore

* allow empty
mvpatel2000 authored Jan 30, 2024
1 parent 5e2c3eb commit c8f1385
Showing 5 changed files with 31 additions and 25 deletions.
19 changes: 12 additions & 7 deletions composer/trainer/mosaic_fsdp.py
@@ -12,9 +12,6 @@
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp import FullyShardedDataParallel

from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
                                                custom_auto_wrap_t1p13p1)


def patch_pytorch():
"""Monkey patches pytorch functions based on pytorch version."""
@@ -25,6 +22,7 @@ def patch_pytorch():
# Monkey patch for torch < 2.0 ie torch == 1.13.1

# Monkey patch _auto_wrap with _custom_auto_wrap fn
from composer.trainer.mosaic_fsdp_utils import custom_auto_wrap_t1p13p1
FullyShardedDataParallel._auto_wrap = custom_auto_wrap_t1p13p1 # type: ignore

elif version.parse(torch.__version__) < version.parse('2.0.1'):
@@ -39,16 +37,23 @@ def patch_pytorch():
FullyShardedDataParallel.__init__ = init_fn_t2p0p1 # type: ignore

# Monkey patch sharding method
from composer.trainer.mosaic_fsdp_utils import build_metadata

ChunkShardingSpec.build_metadata = build_metadata

elif version.parse(torch.__version__) < version.parse('2.1.1'):
# Monkey patch for torch < 2.1.1 ie torch == 2.1.0

# Monkey patch sharding method
from composer.trainer.mosaic_fsdp_utils import build_metadata

ChunkShardingSpec.build_metadata = build_metadata

# Monkey patch partial state dict handling
from torch.distributed.fsdp import _state_dict_utils

from composer.trainer.mosaic_fsdp_utils import _sharded_pre_load_state_dict_hook

_state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)

# Allow 2D HSDP
@@ -79,14 +84,14 @@ def patch_pytorch():
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Monkeypath state_dict
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
FullyShardedDataParallel.__init__ = init_fn_t2p3p0

# Monkeypath state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
state_dict._verify_options = _verify_options_t2p2p0
from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
state_dict._verify_options = _verify_options_t2p3p0

# Monkeypatch sharding optim state
from torch.distributed.fsdp import _optim_utils
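A minimal sketch of the version-gated monkey-patching pattern that patch_pytorch() uses above (class and function names here are illustrative, not composer's API): each branch checks torch.__version__ with packaging.version and only then imports or defines the replacement, so the module imports cleanly on every supported torch release.

import torch
from packaging import version


class _Target:
    """Stand-in for the class being patched (e.g. FullyShardedDataParallel)."""

    def method(self) -> str:
        return 'original'


def _patched_method(self) -> str:
    return 'patched'


def patch_example() -> None:
    # Swap in the replacement only inside the matching branch, mirroring how
    # patch_pytorch() defers its mosaic_fsdp_utils imports per torch version.
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        _Target.method = _patched_method  # patch intended for torch 1.13.x
    elif version.parse(torch.__version__) < version.parse('2.1.1'):
        _Target.method = _patched_method  # patch intended for torch 2.0.x / 2.1.0
    else:
        _Target.method = _patched_method  # newest patch, torch >= 2.1.1


patch_example()
print(_Target().method())  # -> 'patched'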
26 changes: 13 additions & 13 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -762,7 +762,7 @@ def _sharded_pre_load_state_dict_hook(
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)


if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse(
if version.parse(torch.__version__) > version.parse('2.2.9') and version.parse(
torch.__version__) < version.parse('2.3.1'):
import copy

@@ -787,7 +787,7 @@ def _sharded_pre_load_state_dict_hook(
from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions

def all_gather_dtensor_t2p2p0(
def all_gather_dtensor_t2p3p0(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
@@ -806,7 +806,7 @@ def all_gather_dtensor_t2p2p0(
)
return tensor.to_local()

def chunk_dtensor_t2p2p0(
def chunk_dtensor_t2p3p0(
self,
tensor: torch.Tensor,
rank: int,
@@ -869,10 +869,10 @@ def chunk_dtensor_t2p2p0(
placements=shard_placements,
)

DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0
DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0
DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p3p0
DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p3p0

def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
def _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh: DeviceMesh) -> bool:
#parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
#if parent_mesh is not None:
# raise RuntimeError(
@@ -881,13 +881,13 @@ def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
# )
return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2

def _init_process_group_state_for_hybrid_shard_t2p2p0(
def _init_process_group_state_for_hybrid_shard_t2p3p0(
state: _FSDPState,
process_group: ProcessGroupType,
device_mesh: DeviceMesh,
) -> _FSDPState:
if device_mesh:
if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh):
if _is_valid_hybrid_shard_device_mesh_t2p3p0(device_mesh):
state._device_mesh = device_mesh
# We currently only allow _inter_node_pg to be the outermost dimension, and the
# process_group(intra_node) to be the innermost dimension.
@@ -917,7 +917,7 @@ def _init_process_group_state_for_hybrid_shard_t2p2p0(
state._inter_node_state = _get_default_comm_hook_state(process_group=state._inter_node_pg,)
return state

def _init_process_group_state_t2p2p0(
def _init_process_group_state_t2p3p0(
state: _FSDPState,
process_group: ProcessGroupType,
sharding_strategy: ShardingStrategy,
@@ -938,7 +938,7 @@ def _init_process_group_state_t2p2p0(
'requires explicit specification of process group or device_mesh.',
)
else:
state = _init_process_group_state_for_hybrid_shard_t2p2p0(state, process_group, device_mesh)
state = _init_process_group_state_for_hybrid_shard_t2p3p0(state, process_group, device_mesh)
else:
if device_mesh:
state._device_mesh = device_mesh
@@ -956,7 +956,7 @@ def _init_process_group_state_t2p2p0(
state._gradient_postdivide_factor = (data_parallel_world_size / state._gradient_predivide_factor)
return state

def init_fn_t2p2p0(
def init_fn_t2p3p0(
self,
module: nn.Module,
process_group: ProcessGroupType = None,
@@ -990,7 +990,7 @@ def init_fn_t2p2p0(
# Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
# the same process group state as the root FSDP module.
self._device_mesh = device_mesh
_init_process_group_state_t2p2p0(
_init_process_group_state_t2p3p0(
self,
process_group,
sharding_strategy,
@@ -1064,7 +1064,7 @@ def init_fn_t2p2p0(

from torch.distributed.checkpoint.state_dict import StateDictOptions, _StateDictInfo

def _verify_options_t2p2p0(
def _verify_options_t2p3p0(
model: nn.Module,
optims: Tuple[torch.optim.Optimizer, ...],
optim_only: bool,
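A note on the updated version gate in this file: under packaging's ordering, pre-release builds of 2.3.0 compare greater than 2.2.9, so a '> 2.2.9 and < 2.3.1' window admits torch 2.3.0 and its dev/nightly builds while excluding all 2.2.x releases. A small illustrative check (version strings chosen only for illustration):

from packaging import version


def in_patched_range(v: str) -> bool:
    parsed = version.parse(v)
    return version.parse('2.2.9') < parsed < version.parse('2.3.1')


# Expected output: 2.2.0 False, 2.3.0.dev20240126 True, 2.3.0 True, 2.3.1 False
for candidate in ('2.2.0', '2.3.0.dev20240126', '2.3.0', '2.3.1'):
    print(candidate, in_patched_range(candidate))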
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
@@ -2531,7 +2531,7 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_num_samples / current_batch_size)

if use_grad_scaling:
microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss))
microbatch_loss = cast(torch.Tensor, self.state.scaler.scale(microbatch_loss)) # type: ignore

if self.state.deepspeed_enabled:
self.state.deepspeed_model.backward(microbatch_loss)
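For reference, a hedged sketch of the grad-scaling pattern the changed line above participates in (variable names here are hypothetical, and composer's state object presumably types its scaler as Optional): GradScaler.scale() is annotated to return a Tensor or an iterable of Tensors, so the call site narrows the result with cast() and silences the Optional access with # type: ignore.

from typing import Optional, cast

import torch

scaler: Optional[torch.cuda.amp.GradScaler] = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def scale_loss(loss: torch.Tensor) -> torch.Tensor:
    assert scaler is not None
    # scale() may return an iterable in general, so narrow it back to a Tensor here.
    return cast(torch.Tensor, scaler.scale(loss))


print(scale_loss(torch.tensor(1.5, device=device)))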
1 change: 1 addition & 0 deletions composer/utils/fx_utils.py
@@ -234,6 +234,7 @@ def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[
f'Input to apply_stochastic_residual should be an instance of GraphModule. Received {type(gm)}')
all_tags, count = _tag_residual_nodes(gm)
split_gm = split_by_tags(gm, all_tags)
assert isinstance(split_gm, GraphModule)
for node in split_gm.graph.nodes:
if node.op != 'call_module':
continue
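The new assert in apply_stochastic_residual is the standard isinstance-narrowing idiom: depending on its arguments and torch version, split_by_tags may be annotated to return either a GraphModule or a tuple, so the assert documents the expectation at runtime and narrows the type for static checkers. A generic sketch of the same idiom (add_one is an illustrative name):

from typing import Union


def add_one(value: Union[int, str]) -> int:
    # The assert both validates the runtime assumption and lets a static
    # checker (mypy/pyright) treat `value` as an int from this point on.
    assert isinstance(value, int)
    return value + 1


print(add_one(41))  # -> 42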
8 changes: 4 additions & 4 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -220,13 +220,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq1 = [tiny_gpt2_tokenizer.pad_token_id] * (len(seq2) - len(seq1)) + seq1
input_ids = torch.LongTensor([seq1, seq2])
assert not eos_criteria(input_ids, None)
assert not eos_criteria(input_ids, None) # type: ignore

eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
input_ids = torch.LongTensor([seq1, seq2])
assert eos_criteria(input_ids, None)
assert eos_criteria(input_ids, None) # type: ignore


def test_stop_sequences_criteria_sentencepiece(tiny_llama_tokenizer):
@@ -238,13 +238,13 @@ def test_stop_sequences_criteria_sentencepiece(tiny_llama_tokenizer):
seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
seq1 = [tokenizer.eos_token_id] * (len(seq2) - len(seq1)) + seq1
input_ids = torch.LongTensor([seq1, seq2])
assert not eos_criteria(input_ids, None)
assert not eos_criteria(input_ids, None) # type: ignore

eos_criteria = MultiTokenEOSCriteria('\n\n', tokenizer, 2)
seq1 = tokenizer('Dogs are furry\n\n')['input_ids']
seq2 = tokenizer('Dogs are furry\n\n')['input_ids']
input_ids = torch.LongTensor([seq1, seq2])
assert eos_criteria(input_ids, None)
assert eos_criteria(input_ids, None) # type: ignore


@pytest.mark.filterwarnings(
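For context on the # type: ignore additions in these tests, here is a hedged sketch of a multi-token stopping criterion in the spirit of MultiTokenEOSCriteria (the real class lives in composer and its implementation may differ; SimpleMultiTokenEOS and stop_ids are illustrative names). The Hugging Face base class annotates the second argument as a FloatTensor of scores, but a criterion that only inspects input_ids never reads it, which is why the tests can pass None for scores once the call is marked # type: ignore.

import torch
from transformers import StoppingCriteria


class SimpleMultiTokenEOS(StoppingCriteria):

    def __init__(self, stop_ids: list):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        n = len(self.stop_ids)
        # Stop only once every sequence in the batch ends with the stop-token run.
        return all(row[-n:].tolist() == self.stop_ids for row in input_ids)


criteria = SimpleMultiTokenEOS(stop_ids=[50256, 50256])
batch = torch.LongTensor([[1, 2, 50256, 50256], [3, 4, 50256, 50256]])
print(criteria(batch, None))  # -> True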
