Skip to content

Commit

Permalink
Fix typo in Block.
Browse files Browse the repository at this point in the history
  • Loading branch information
xkszltl committed Jan 29, 2024
1 parent 0548af5 commit 91e7b0a
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def forward(
}


class MixtralBLockSparseTop2MLP(nn.Module):
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
Expand All @@ -805,6 +805,14 @@ def forward(self, hidden_states):
return current_hidden_states


class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
    """Deprecated, misspelled alias of ``MixtralBlockSparseTop2MLP``.

    Kept only for backward compatibility with code that imported the old
    name; it behaves exactly like the correctly spelled parent class.
    """

    def __init__(self, *args, **kwargs):
        # Warn once per process, then delegate construction entirely to the
        # correctly spelled parent class.
        logger.warning_once(
            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)


class MixtralSparseMoeBlock(nn.Module):
"""
This implementation is
Expand All @@ -827,7 +835,7 @@ def __init__(self, config):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
Expand Down

0 comments on commit 91e7b0a

Please sign in to comment.