Experimental Mixtral support
cg123 committed Dec 12, 2023
1 parent ca80afe commit 247b3da
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions mergekit/architecture.py
@@ -108,6 +108,38 @@ def num_layers(self, config: PretrainedConfig) -> int:
)


class MixtralTensorNames(ArchitectureInfo):
    architecture_name: str = "MixtralForCausalLM"

    def __init__(self, config: PretrainedConfig):
        self.config = config

    def pre_weights(self) -> List[str]:
        return MISTRAL_INFO.pre_weights()

    def post_weights(self) -> List[str]:
        return MISTRAL_INFO.post_weights()

    def embed_weights(self) -> List[str]:
        return MISTRAL_INFO.embed_weights()

    def num_layers_config_key(self) -> str:
        return MISTRAL_INFO.num_layers_config_key()

    def layer_weight_formats(self) -> List[str]:
        num_experts = self.config.num_local_experts
        # Reuse Mistral's per-layer weights, dropping the dense MLP entries
        # that Mixtral replaces with sparse MoE experts.
        res = [
            fmt for fmt in MISTRAL_INFO.layer_weight_formats() if ".mlp." not in fmt
        ]
        # Each expert contributes three projection matrices: w1, w2, and w3.
        for expert_idx in range(num_experts):
            for param in ("w1", "w2", "w3"):
                fmt = (
                    MISTRAL_INFO.layer_prefix_format
                    + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
                )
                res.append(fmt)
        # A single per-layer router (gate) selects which experts process each token.
        res.append(MISTRAL_INFO.layer_prefix_format + ".block_sparse_moe.gate.weight")
        return res
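
# Illustrative expansion (assuming MISTRAL_INFO.layer_prefix_format is
# "model.layers.{idx}"): with num_local_experts=2, each layer yields the
# Mistral attention and norm weights plus:
#   model.layers.{idx}.block_sparse_moe.experts.0.w1.weight
#   model.layers.{idx}.block_sparse_moe.experts.0.w2.weight
#   model.layers.{idx}.block_sparse_moe.experts.0.w3.weight
#   model.layers.{idx}.block_sparse_moe.experts.1.w1.weight
#   model.layers.{idx}.block_sparse_moe.experts.1.w2.weight
#   model.layers.{idx}.block_sparse_moe.experts.1.w3.weight
#   model.layers.{idx}.block_sparse_moe.gate.weight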


STABLELM_INFO = StaticTensorNames(
name="StableLMEpochForCausalLM",
post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
@@ -289,6 +321,8 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
    arch_name = config.architectures[0]
    if arch_name == PhiTensorNames.architecture_name:
        return PhiTensorNames(config)
    if arch_name == MixtralTensorNames.architecture_name:
        return MixtralTensorNames(config)

    supported = [
        LLAMA_INFO,
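A minimal usage sketch of the new dispatch path (the checkpoint ID and the "{idx}" placeholder in the per-layer templates are assumptions for illustration, not part of this commit):

from transformers import AutoConfig
from mergekit.architecture import get_architecture_info

# Any MixtralForCausalLM checkpoint would work; this ID is assumed.
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
info = get_architecture_info(config)  # dispatches to MixtralTensorNames

# Expand the per-layer tensor name templates for layer 0, assuming
# MISTRAL_INFO.layer_prefix_format uses an "{idx}" placeholder.
for fmt in info.layer_weight_formats():
    print(fmt.format(idx=0))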
