diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index fcb24d258..ca7c2c3bd 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -679,7 +679,7 @@ def __init__(
         self.disable_grad_reduce = disable_grad_reduce
 
         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
         if self.explicit_expert_comm and config.moe_extended_tp:
             world_size = get_tensor_and_expert_parallel_world_size()
@@ -941,7 +941,7 @@ def __init__(
             raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
 
         self.explicit_expert_comm = self.is_expert and (
-            config.sequence_parallel or self.expert_parallel
+            config.tensor_model_parallel_size > 1 or self.expert_parallel
         )
 
         # Divide the weight matrix along the last dimension.
diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py
index ba3750011..d42f409a0 100644
--- a/megatron/core/transformer/moe/moe_layer.py
+++ b/megatron/core/transformer/moe/moe_layer.py
@@ -90,6 +90,16 @@ def __init__(
         self.moe_layer_recompute = config.moe_layer_recompute
 
     def forward(self, hidden_states: torch.Tensor):
+        if (
+            self.training
+            and self.config.tensor_model_parallel_size > 1
+            and not self.config.sequence_parallel
+        ):
+            raise ValueError(
+                "During training, performance may degrade if MoE and tensor parallelism "
+                "are enabled without also enabling sequence parallelism."
+            )
+
         # process MoE
         def custom_forward(hidden_states):
             probs, indices = self.router(hidden_states)
diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py
index 515a96ff4..e0e112d94 100644
--- a/megatron/core/transformer/moe/token_dispatcher.py
+++ b/megatron/core/transformer/moe/token_dispatcher.py
@@ -107,7 +107,9 @@ def token_permutation(
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
 
         # Permute the tokens across the expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             with torch.no_grad():
                 global_indices = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
                     max_ind
@@ -214,7 +216,9 @@ def token_unpermutation(
                 output_bias_total = unpermuted_local_bias
 
         # Unpermute the tokens across expert parallel devices.
-        if self.config.sequence_parallel or (self.config.expert_model_parallel_size > 1):
+        if (self.config.tensor_model_parallel_size > 1) or (
+            self.config.expert_model_parallel_size > 1
+        ):
             assert (
                 self.global_local_map is not None
             ), "global_local_map is necessary for `AllGather`."
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 6b038669f..c829c52f1 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -501,9 +501,6 @@ def validate_args(args, defaults={}):
     # MoE Spec check
     if args.num_experts is not None:
         assert args.spec is None, "Model Spec must be None when using MoEs"
-        if args.tensor_model_parallel_size > 1:
-            assert args.sequence_parallel, \
-                "When using MoE and tensor parallelism, sequence parallelism must be used."
 
     # Expert parallelism check
    if args.expert_model_parallel_size > 1:
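For context, a minimal standalone sketch of what the gating change amounts to. This is not part of the patch: `SimpleConfig`, `needs_token_gather_old`, and `needs_token_gather_new` are hypothetical stand-ins that mirror only the three config fields the diff touches. The dispatcher's gather/scatter path is now keyed on `tensor_model_parallel_size > 1` instead of `sequence_parallel`, and the hard assertion in `validate_args` is replaced by a training-time `ValueError` raised in `MoELayer.forward`.

```python
# Illustrative sketch only -- not part of the patch. `SimpleConfig` is a
# hypothetical stand-in for the TransformerConfig fields the diff touches.
from dataclasses import dataclass


@dataclass
class SimpleConfig:
    tensor_model_parallel_size: int = 1
    expert_model_parallel_size: int = 1
    sequence_parallel: bool = False


def needs_token_gather_old(cfg: SimpleConfig) -> bool:
    """Pre-patch gate in token_permutation()/token_unpermutation()."""
    return cfg.sequence_parallel or cfg.expert_model_parallel_size > 1


def needs_token_gather_new(cfg: SimpleConfig) -> bool:
    """Post-patch gate: keyed on TP size rather than sequence parallelism."""
    return cfg.tensor_model_parallel_size > 1 or cfg.expert_model_parallel_size > 1


if __name__ == "__main__":
    # TP=2 without sequence parallelism: previously rejected outright by the
    # validate_args assertion; after the patch it takes the gather path in the
    # dispatcher, and only MoELayer.forward raises during training.
    cfg = SimpleConfig(tensor_model_parallel_size=2, sequence_parallel=False)
    print(needs_token_gather_old(cfg))  # False
    print(needs_token_gather_new(cfg))  # True
```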