From 91e7b0a83dfc3c573a4d9443e40c13501a86632f Mon Sep 17 00:00:00 2001
From: Tongliang Liao
Date: Fri, 26 Jan 2024 03:48:14 -0800
Subject: [PATCH] Fix typo of `Block`.

---
 src/transformers/models/mixtral/modeling_mixtral.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index a3ece48a5aa74b..b8bc13fbe038bc 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -787,7 +787,7 @@ def forward(
 }
 
 
-class MixtralBLockSparseTop2MLP(nn.Module):
+class MixtralBlockSparseTop2MLP(nn.Module):
     def __init__(self, config: MixtralConfig):
         super().__init__()
         self.ffn_dim = config.intermediate_size
@@ -805,6 +805,14 @@ def forward(self, hidden_states):
         return current_hidden_states
 
 
+class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
+    def __init__(self, *args, **kwargs):
+        logger.warning_once(
+            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
+        )
+        super().__init__(*args, **kwargs)
+
+
 class MixtralSparseMoeBlock(nn.Module):
     """
     This implementation is
@@ -827,7 +835,7 @@ def __init__(self, config):
 
         # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
+        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
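
A minimal sketch of how the backward-compatibility shim in this patch behaves, assuming a transformers build with the change applied; the tiny config sizes are illustrative only and not part of the patch:

    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import (
        MixtralBLockSparseTop2MLP,  # deprecated alias kept by this patch
        MixtralBlockSparseTop2MLP,  # corrected class name
    )

    # Small hidden/intermediate sizes purely to keep the example lightweight.
    config = MixtralConfig(hidden_size=8, intermediate_size=16)

    # Instantiating the old name logs a one-time deprecation warning via
    # logger.warning_once and otherwise behaves like the renamed class.
    legacy_expert = MixtralBLockSparseTop2MLP(config)
    assert isinstance(legacy_expert, MixtralBlockSparseTop2MLP)

Because the old name subclasses the new one, existing imports and isinstance checks keep working until the alias is removed in v4.40.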