diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 08b66466..7d44f214 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -108,6 +108,38 @@ def num_layers(self, config: PretrainedConfig) -> int:
         )
 
 
+class MixtralTensorNames(ArchitectureInfo):
+    architecture_name: str = "MixtralForCausalLM"
+
+    def __init__(self, config: PretrainedConfig):
+        self.config = config
+
+    def pre_weights(self) -> List[str]:
+        return MISTRAL_INFO.pre_weights()
+
+    def post_weights(self) -> List[str]:
+        return MISTRAL_INFO.post_weights()
+
+    def embed_weights(self) -> List[str]:
+        return MISTRAL_INFO.embed_weights()
+
+    def num_layers_config_key(self) -> str:
+        return MISTRAL_INFO.num_layers_config_key()
+
+    def layer_weight_formats(self) -> List[str]:
+        num_experts = self.config.num_local_experts
+        res = [fmt for fmt in MISTRAL_INFO.layer_weight_formats() if ".mlp." not in fmt]
+        for expert_idx in range(num_experts):
+            for param in ("w1", "w2", "w3"):
+                fmt = (
+                    MISTRAL_INFO.layer_prefix_format
+                    + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
+                )
+                res.append(fmt)
+        res.append(MISTRAL_INFO.layer_prefix_format + ".block_sparse_moe.gate.weight")
+        return res
+
+
 STABLELM_INFO = StaticTensorNames(
     name="StableLMEpochForCausalLM",
     post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
@@ -289,6 +321,8 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
     arch_name = config.architectures[0]
     if arch_name == PhiTensorNames.architecture_name:
         return PhiTensorNames(config)
+    if arch_name == MixtralTensorNames.architecture_name:
+        return MixtralTensorNames(config)
 
     supported = [
         LLAMA_INFO,
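
For reference, here is a minimal standalone sketch of the per-expert name expansion that layer_weight_formats() performs in the diff above. The "model.layers.{idx}" value is an assumed stand-in for MISTRAL_INFO.layer_prefix_format; the real value comes from mergekit's MISTRAL_INFO definition.

# Illustrative sketch only -- not part of the diff. It reproduces the
# expert name expansion from layer_weight_formats() in isolation.
# "model.layers.{idx}" is an assumed value for MISTRAL_INFO.layer_prefix_format.
from typing import List

LAYER_PREFIX_FORMAT = "model.layers.{idx}"  # assumption, for illustration


def mixtral_expert_weight_formats(num_experts: int) -> List[str]:
    # Each expert contributes a (w1, w2, w3) triple of MLP weights;
    # a single gate per layer routes tokens among the experts.
    res: List[str] = []
    for expert_idx in range(num_experts):
        for param in ("w1", "w2", "w3"):
            res.append(
                LAYER_PREFIX_FORMAT
                + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
            )
    res.append(LAYER_PREFIX_FORMAT + ".block_sparse_moe.gate.weight")
    return res


# With num_experts=2 this yields seven name templates per layer:
#   model.layers.{idx}.block_sparse_moe.experts.0.w1.weight
#   ...
#   model.layers.{idx}.block_sparse_moe.experts.1.w3.weight
#   model.layers.{idx}.block_sparse_moe.gate.weight
print(mixtral_expert_weight_formats(2))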