Skip to content

Commit

Permalink
Better tied weight handling (#464)
Browse files Browse the repository at this point in the history
Handle cases where some input models have a tied tensor and some don't.

For example, there are some fine tunes of Llama 3.2 3B floating around
that are ~3.6B parameters because they have a separate LM head - with
these changes these can be merged with standard sized ones. There will
be a LM head in the output model if any inputs have one. Otherwise
behavior will be as it was before.
  • Loading branch information
cg123 authored Nov 30, 2024
1 parent afe3780 commit 68c4b65
Show file tree
Hide file tree
Showing 18 changed files with 91 additions and 37 deletions.
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/bert-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
},
{
"name": "cls.predictions.decoder.weight",
"aliases": [
"optional": true,
"tied_names": [
"bert.embeddings.word_embeddings.weight"
],
"is_embed": true
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/distilbert-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
{
"name": "vocab_projector.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"distilbert.embeddings.word_embeddings.weight"
]
},
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/gemma2.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"optional": true
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
]
}
4 changes: 3 additions & 1 deletion mergekit/_data/architectures/gptbigcode.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
},
{
"name": "lm_head.weight",
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"transformer.wte.weight"
]
}
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/internlm2.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
{
"name": "output.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.tok_embeddings.weight"
]
}
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/llama.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@
"name": "lm_head.weight",
"input_space": "running_residual",
"is_embed": true,
"optional": true
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
]
}
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/mamba.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": ["backbone.embeddings.weight"]
"optional": true,
"tied_names": [
"backbone.embeddings.weight"
]
}
],
"num_layers_config_key": "num_hidden_layers",
Expand Down
5 changes: 3 additions & 2 deletions mergekit/_data/architectures/phi3-small.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"post_weights": [
{
"name": "lm_head.weight",
"is_embed":true,
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
},
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/qwen2.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
Expand Down
7 changes: 5 additions & 2 deletions mergekit/_data/architectures/roberta-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"name": "roberta.embeddings.position_embeddings.weight"
},
{
"name": "roberta.embeddings.word_embeddings.weight"
"name": "roberta.embeddings.word_embeddings.weight",
"is_embed": true
},
{
"name": "roberta.embeddings.token_type_embeddings.weight"
Expand Down Expand Up @@ -43,7 +44,9 @@
},
{
"name": "lm_head.decoder.weight",
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"roberta.embeddings.word_embeddings.weight"
]
}
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/solar.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
"name": "lm_head.weight",
"input_space": "running_residual",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.lm_head.weight"
]
}
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/starcoder2.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": ["model.embed_tokens.weight"]
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
},
{
"name": "model.norm.bias"
Expand Down
4 changes: 4 additions & 0 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True):
Indicates whether the weight can be omitted from a model.
aliases (Optional[List[str]]):
List of alternative names for the weight, if applicable.
tied_names (Optional[List[str]]):
List of names for weights that are tied to this weight, if applicable.
force_dtype (Optional[str]):
Mandatory dtype for the weight, if applicable.
"""
Expand All @@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True):
input_space: Optional[str] = None
output_space: Optional[str] = None
optional: bool = False
tied: bool = False
aliases: Optional[Tuple[str, ...]] = None
tied_names: Optional[Tuple[str, ...]] = None
force_dtype: Optional[str] = None
head_split: Literal[None, "input", "output"] = None
is_kq: Optional[bool] = False
Expand Down
6 changes: 5 additions & 1 deletion mergekit/io/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]):
device: Optional[str] = None
optional: bool = False
aliases: Optional[Tuple[str, ...]] = None
tied_names: Optional[Tuple[str, ...]] = None

def arguments(self) -> Dict[str, Task]:
return {}

def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
all_names = [self.tensor] + list(self.aliases or [])
all_names = (
[self.tensor] + list(self.aliases or []) + list(self.tied_names or [])
)
for name in all_names:
if name in loader.index.tensor_paths:
return name
Expand Down Expand Up @@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]:
device=self.device,
optional=wi.optional,
aliases=wi.aliases,
tied_names=wi.tied_names,
)
for (model, wi) in self.weight_info.items()
}
Expand Down
2 changes: 1 addition & 1 deletion mergekit/io/tensor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def finalize(self):
json.dump(
{
"metadata": {
"mergekit_version": "0.0.5.1",
"mergekit_version": "0.0.5.2",
"total_size": self.total_size,
},
"weight_map": self.weight_map,
Expand Down
5 changes: 4 additions & 1 deletion mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def plan_tensor(
any_weight = False
for model, w_in in zip(models, weights_in):
index = LoaderCache().get(model).index
if w_in.name in index.tensor_paths:
if any(
name in index.tensor_paths
for name in [w_in.name] + (w_in.aliases or [])
):
any_weight = True
break

Expand Down
58 changes: 39 additions & 19 deletions mergekit/scripts/tokensurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,42 @@ def main(
)

if lm_head_info:
old_lm_head = cache.get(model).get_tensor(
lm_head_info.name, aliases=lm_head_info.aliases, device=device
)
donor_lm_head = cache.get(donor).get_tensor(
donor_lm_head_info.name, aliases=donor_lm_head_info.aliases, device=device
)
try:
old_lm_head = cache.get(model).get_tensor(
lm_head_info.name, aliases=lm_head_info.aliases, device=device
)
except KeyError:
if lm_head_info.optional:
logging.info(f"LM head tensor {lm_head_info.name} not found, skipping")
else:
report_issue(
f"Could not load LM head tensor {lm_head_info.name}",
error=True,
)
old_lm_head = None

LOG.info("Computing new lm_head embeddings")
new_lm_head = get_embeddings(
old_lm_head,
donor_lm_head,
old_vocab,
new_vocab,
common_tokens,
accept_prefix=True,
k=k,
barycentric=barycentric,
cosine_similarity=cosine_similarity,
name=lm_head_info.name,
)
if old_lm_head is not None:
donor_lm_head = cache.get(donor).get_tensor(
donor_lm_head_info.name,
aliases=donor_lm_head_info.aliases,
device=device,
)

LOG.info("Computing new lm_head embeddings")
new_lm_head = get_embeddings(
old_lm_head,
donor_lm_head,
old_vocab,
new_vocab,
common_tokens,
accept_prefix=True,
k=k,
barycentric=barycentric,
cosine_similarity=cosine_similarity,
name=lm_head_info.name,
)
else:
new_lm_head = None

# Save out the new model
LOG.info(f"Saving new model to {out_path}")
Expand All @@ -184,6 +200,10 @@ def main(
tensor = cache.get(model).get_tensor(
weight_info.name, aliases=weight_info.aliases
)
if tensor is None:
if weight_info.optional:
continue
report_issue(f"Could not load weight tensor {weight_info.name}", error=True)
writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)
writer.finalize()

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "mergekit"
description = "Tools for merging pre-trained large language models"
readme = "README.md"
license = { text = "LGPL-3.0-or-later" }
version = "0.0.5.1"
version = "0.0.5.2"
authors = [{ name = "Charles Goddard", email = "[email protected]" }]
dependencies = [
"torch>=2.0.0",
Expand Down

0 comments on commit 68c4b65

Please sign in to comment.