diff --git a/mergekit/_data/architectures/deepseekv2.json b/mergekit/_data/architectures/deepseekv2.json
new file mode 100644
index 00000000..c06a1b85
--- /dev/null
+++ b/mergekit/_data/architectures/deepseekv2.json
@@ -0,0 +1,50 @@
+{
+  "model_type": "deepseekv2",
+  "architectures": [
+    "DeepseekV2ForCausalLM"
+  ],
+  "pre_weights": [
+    {
+      "name": "model.embed_tokens.weight",
+      "is_embed": true
+    }
+  ],
+  "post_weights": [
+    {
+      "name": "model.norm.weight"
+    },
+    {
+      "name": "lm_head.weight",
+      "is_embed": true,
+      "aliases": [
+        "model.embed_tokens.weight"
+      ]
+    }
+  ],
+  "num_layers_config_key": "num_hidden_layers",
+  "layer_templates": {
+    "weights": [
+      {
+        "name": "model.layers.${layer_index}.self_attn.q_proj.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.self_attn.kv_a_proj_with_mqa.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.self_attn.kv_a_layernorm.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.self_attn.kv_b_proj.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.self_attn.o_proj.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.input_layernorm.weight"
+      },
+      {
+        "name": "model.layers.${layer_index}.post_attention_layernorm.weight"
+      }
+    ]
+  }
+}
diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 4c7b4625..073e37d2 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -326,6 +326,57 @@ def sliceable(self) -> bool:
 
     def has_defined_spaces(self) -> bool:
         return False
+class DeepseekV2TensorNames(ArchitectureInfo, BaseModel):
+    ARCHITECTURE_NAME: ClassVar[str] = "DeepseekV2ForCausalLM"
+    num_local_experts: int
+
+    def name(self) -> str:
+        return "DeepseekV2"
+
+    @classmethod
+    def from_config(cls, config: PretrainedConfig):
+        return DeepseekV2TensorNames(num_local_experts=config.n_routed_experts)
+
+    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return DEEPSEEKV2_INFO.pre_weights(config)
+
+    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return DEEPSEEKV2_INFO.post_weights(config)
+
+    def num_layers_config_key(self) -> str:
+        return DEEPSEEKV2_INFO.num_layers_config_key()
+
+    def layer_weights(
+        self, index: int, config: PretrainedConfig
+    ) -> Optional[List[WeightInfo]]:
+        num_experts = self.num_local_experts
+        prefix = f"model.layers.{index}"
+        tensor_names = []
+
+        if index > 0:
+            for expert_idx in range(num_experts):
+                for param in ("gate_proj", "up_proj", "down_proj"):
+                    tensor_names.append(
+                        prefix + f".mlp.experts.{expert_idx}.{param}.weight"
+                    )
+            tensor_names.append(prefix + ".mlp.gate.weight")
+
+            tensor_names.append(prefix + ".mlp.shared_experts.gate_proj.weight")
+            tensor_names.append(prefix + ".mlp.shared_experts.up_proj.weight")
+            tensor_names.append(prefix + ".mlp.shared_experts.down_proj.weight")
+
+        res = [WeightInfo(name=name) for name in tensor_names]
+
+        for weight_info in DEEPSEEKV2_INFO.layer_weights(index, config):
+            res.append(weight_info)
+
+        return res
+
+    def sliceable(self) -> bool:
+        return True
+
+    def has_defined_spaces(self) -> bool:
+        return False
 
 def _load_json_arch(name: str) -> JsonArchitectureInfo:
     text = importlib.resources.read_text(mergekit._data.architectures, name)
@@ -353,6 +353,7 @@ def _load_all_architectures() -> (
 JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures()
 MISTRAL_INFO = _load_json_arch("mistral.json")
 QWEN2_INFO = _load_json_arch("qwen2.json")
+DEEPSEEKV2_INFO = _load_json_arch("deepseekv2.json")
 
 
 def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: