diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a3ece48a5aa74b..b8bc13fbe038bc 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -787,7 +787,7 @@ def forward( } -class MixtralBLockSparseTop2MLP(nn.Module): +class MixtralBlockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() self.ffn_dim = config.intermediate_size @@ -805,6 +805,14 @@ def forward(self, hidden_states): return current_hidden_states +class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP): + def __init__(self, *args, **kwargs): + logger.warning_once( + "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40." + ) + super().__init__(*args, **kwargs) + + class MixtralSparseMoeBlock(nn.Module): """ This implementation is @@ -827,7 +835,7 @@ def __init__(self, config): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """