From e2570ed387295743eaf4333d148666b297e486c0 Mon Sep 17 00:00:00 2001 From: "Charles O. Goddard" Date: Fri, 28 Jun 2024 19:30:33 -0700 Subject: [PATCH 1/2] Support Qwen2 models with tied weights (#358) Should resolve #350. --- mergekit/_data/architectures/qwen2.json | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mergekit/_data/architectures/qwen2.json b/mergekit/_data/architectures/qwen2.json index 292c1e52..638b3630 100644 --- a/mergekit/_data/architectures/qwen2.json +++ b/mergekit/_data/architectures/qwen2.json @@ -15,7 +15,10 @@ }, { "name": "lm_head.weight", - "is_embed": true + "is_embed": true, + "aliases": [ + "model.embed_tokens.weight" + ] } ], "num_layers_config_key": "num_hidden_layers", From 21937cd2b37f8b6d193843f67eb12695d1ae4b96 Mon Sep 17 00:00:00 2001 From: "Charles O. Goddard" Date: Fri, 28 Jun 2024 20:43:31 -0700 Subject: [PATCH 2/2] Gemma2 support (#359) Plus some minor bugfixes. --- mergekit/_data/architectures/gemma2.json | 62 ++++++++++++++++++++++++ mergekit/io/tensor_writer.py | 2 +- mergekit/merge.py | 4 +- pyproject.toml | 4 +- 4 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 mergekit/_data/architectures/gemma2.json diff --git a/mergekit/_data/architectures/gemma2.json b/mergekit/_data/architectures/gemma2.json new file mode 100644 index 00000000..aeca0cc8 --- /dev/null +++ b/mergekit/_data/architectures/gemma2.json @@ -0,0 +1,62 @@ +{ + "model_type": "gemma2", + "architectures": [ + "Gemma2ForCausalLM" + ], + "pre_weights": [ + { + "name": "model.embed_tokens.weight", + "is_embed": true + } + ], + "num_layers_config_key": "num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "model.layers.${layer_index}.input_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.o_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_attention_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.up_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.gate_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.down_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" + } + ] + }, + "post_weights": [ + { + "name": "model.norm.weight" + }, + { + "name": "lm_head.weight", + "is_embed": true, + "aliases": [ + "model.embed_tokens.weight" + ] + } + ] +} diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index bd34fee2..1483a3c3 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -121,7 +121,7 @@ def finalize(self): json.dump( { "metadata": { - "mergekit_version": "0.0.4.2", + "mergekit_version": "0.0.4.4", "total_size": self.total_size, }, "weight_map": self.weight_map, diff --git a/mergekit/merge.py b/mergekit/merge.py index d045644c..abdf85a3 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -179,7 +179,9 @@ def _model_out_config( res = config.base_model.config(trust_remote_code=trust_remote_code) else: res = config.referenced_models()[0].config(trust_remote_code=trust_remote_code) - if config.dtype: + if config.out_dtype: + res.torch_dtype = config.out_dtype + elif config.dtype: res.torch_dtype = config.dtype if config.slices: diff --git a/pyproject.toml b/pyproject.toml index 9a0e9db1..7cf524a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "mergekit" description = "Tools for merging pre-trained large language models" readme = "README.md" license = { text = "LGPL-3.0-or-later" } -version = "0.0.4.3" +version = "0.0.4.4" authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }] dependencies = [ "torch>=2.0.0", @@ -17,7 +17,7 @@ dependencies = [ "accelerate~=0.30.1", "pydantic==2.7.1", "immutables==0.20", - "transformers>=4.39.3", + "transformers>=4.42.3", "huggingface_hub", "peft", "typing-extensions",