Skip to content

Commit

Permalink
Merge branch 'main' into nuslerp
Browse files: browse the repository at this point in the history
  • Loading branch information
cg123 authored Jun 29, 2024
2 parents 272d887 + 21937cd commit 8a57623
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 5 deletions.
62 changes: 62 additions & 0 deletions mergekit/_data/architectures/gemma2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
{
"model_type": "gemma2",
"architectures": [
"Gemma2ForCausalLM"
],
"pre_weights": [
{
"name": "model.embed_tokens.weight",
"is_embed": true
}
],
"num_layers_config_key": "num_hidden_layers",
"layer_templates": {
"weights": [
{
"name": "model.layers.${layer_index}.input_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.q_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.k_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.v_proj.weight"
},
{
"name": "model.layers.${layer_index}.self_attn.o_proj.weight"
},
{
"name": "model.layers.${layer_index}.post_attention_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight"
},
{
"name": "model.layers.${layer_index}.mlp.up_proj.weight"
},
{
"name": "model.layers.${layer_index}.mlp.gate_proj.weight"
},
{
"name": "model.layers.${layer_index}.mlp.down_proj.weight"
},
{
"name": "model.layers.${layer_index}.post_feedforward_layernorm.weight"
}
]
},
"post_weights": [
{
"name": "model.norm.weight"
},
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": [
"model.embed_tokens.weight"
]
}
]
}
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/qwen2.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
},
{
"name": "lm_head.weight",
- "is_embed": true
+ "is_embed": true,
+ "aliases": [
+ "model.embed_tokens.weight"
+ ]
}
],
"num_layers_config_key": "num_hidden_layers",
Expand Down
2 changes: 1 addition & 1 deletion mergekit/io/tensor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def finalize(self):
json.dump(
{
"metadata": {
- "mergekit_version": "0.0.4.2",
+ "mergekit_version": "0.0.4.4",
"total_size": self.total_size,
},
"weight_map": self.weight_map,
Expand Down
4 changes: 3 additions & 1 deletion mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def _model_out_config(
res = config.base_model.config(trust_remote_code=trust_remote_code)
else:
res = config.referenced_models()[0].config(trust_remote_code=trust_remote_code)
- if config.dtype:
+ if config.out_dtype:
+ res.torch_dtype = config.out_dtype
+ elif config.dtype:
res.torch_dtype = config.dtype

if config.slices:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "mergekit"
description = "Tools for merging pre-trained large language models"
readme = "README.md"
license = { text = "LGPL-3.0-or-later" }
- version = "0.0.4.3"
+ version = "0.0.4.4"
authors = [{ name = "Charles Goddard", email = "[email protected]" }]
dependencies = [
"torch>=2.0.0",
Expand All @@ -17,7 +17,7 @@ dependencies = [
"accelerate~=0.30.1",
"pydantic==2.7.1",
"immutables==0.20",
- "transformers>=4.39.3",
+ "transformers>=4.42.3",
"huggingface_hub",
"peft",
"typing-extensions",
Expand Down

0 comments on commit 8a57623

Please sign in to comment.