Merge branch 'main' into pad-embeds-to-multiple
cg123 authored Dec 1, 2024
2 parents 2ab77ab + 8d1a10d commit b540424
Showing 21 changed files with 161 additions and 42 deletions.
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/bert-masked-lm.json
@@ -44,7 +44,8 @@
         },
         {
             "name": "cls.predictions.decoder.weight",
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "bert.embeddings.word_embeddings.weight"
             ],
             "is_embed": true
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/distilbert-masked-lm.json
@@ -40,7 +40,8 @@
         {
             "name": "vocab_projector.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "distilbert.embeddings.word_embeddings.weight"
             ]
         },
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/gemma2.json
@@ -54,7 +54,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "optional": true
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         }
     ]
 }
4 changes: 3 additions & 1 deletion mergekit/_data/architectures/gptbigcode.json
@@ -21,7 +21,9 @@
         },
         {
             "name": "lm_head.weight",
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "transformer.wte.weight"
             ]
         }
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/internlm2.json
@@ -16,7 +16,8 @@
         {
             "name": "output.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.tok_embeddings.weight"
             ]
         }
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/llama.json
@@ -74,7 +74,10 @@
             "name": "lm_head.weight",
             "input_space": "running_residual",
             "is_embed": true,
-            "optional": true
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         }
     ]
 }
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/mamba.json
@@ -16,7 +16,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": ["backbone.embeddings.weight"]
+            "optional": true,
+            "tied_names": [
+                "backbone.embeddings.weight"
+            ]
         }
     ],
     "num_layers_config_key": "num_hidden_layers",
5 changes: 3 additions & 2 deletions mergekit/_data/architectures/phi3-small.json
@@ -12,8 +12,9 @@
     "post_weights": [
         {
             "name": "lm_head.weight",
-            "is_embed":true,
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "model.embed_tokens.weight"
             ]
         },
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/qwen2.json
@@ -16,7 +16,8 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.embed_tokens.weight"
             ]
         }
7 changes: 5 additions & 2 deletions mergekit/_data/architectures/roberta-masked-lm.json
@@ -8,7 +8,8 @@
             "name": "roberta.embeddings.position_embeddings.weight"
         },
         {
-            "name": "roberta.embeddings.word_embeddings.weight"
+            "name": "roberta.embeddings.word_embeddings.weight",
+            "is_embed": true
         },
         {
             "name": "roberta.embeddings.token_type_embeddings.weight"
@@ -43,7 +44,9 @@
         },
         {
             "name": "lm_head.decoder.weight",
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "roberta.embeddings.word_embeddings.weight"
             ]
         }
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/solar.json
@@ -73,7 +73,8 @@
             "name": "lm_head.weight",
             "input_space": "running_residual",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.lm_head.weight"
             ]
         }
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/starcoder2.json
@@ -13,7 +13,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": ["model.embed_tokens.weight"]
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         },
         {
             "name": "model.norm.bias"
4 changes: 4 additions & 0 deletions mergekit/architecture.py
@@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True):
             Indicates whether the weight can be omitted from a model.
         aliases (Optional[List[str]]):
             List of alternative names for the weight, if applicable.
+        tied_names (Optional[List[str]]):
+            List of names for weights that are tied to this weight, if applicable.
         force_dtype (Optional[str]):
             Mandatory dtype for the weight, if applicable.
     """
@@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True):
     input_space: Optional[str] = None
     output_space: Optional[str] = None
     optional: bool = False
+    tied: bool = False
     aliases: Optional[Tuple[str, ...]] = None
+    tied_names: Optional[Tuple[str, ...]] = None
     force_dtype: Optional[str] = None
    head_split: Literal[None, "input", "output"] = None
    is_kq: Optional[bool] = False
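To make the new field concrete: a minimal sketch (not from this commit) of the `WeightInfo` that an entry like llama.json's `lm_head.weight` above would produce. The constructor call is an assumption based on the fields visible in this diff, not mergekit's own construction code.

```python
# Hypothetical usage of the new tied_names field; illustrative only.
from mergekit.architecture import WeightInfo

lm_head = WeightInfo(
    name="lm_head.weight",
    is_embed=True,
    optional=True,  # checkpoints with tied embeddings omit this tensor
    tied_names=("model.embed_tokens.weight",),  # fall back to the input embedding
)
```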
6 changes: 5 additions & 1 deletion mergekit/io/tasks.py
@@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]):
     device: Optional[str] = None
     optional: bool = False
     aliases: Optional[Tuple[str, ...]] = None
+    tied_names: Optional[Tuple[str, ...]] = None

     def arguments(self) -> Dict[str, Task]:
         return {}

     def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
-        all_names = [self.tensor] + list(self.aliases or [])
+        all_names = (
+            [self.tensor] + list(self.aliases or []) + list(self.tied_names or [])
+        )
         for name in all_names:
             if name in loader.index.tensor_paths:
                 return name
@@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]:
                 device=self.device,
                 optional=wi.optional,
                 aliases=wi.aliases,
+                tied_names=wi.tied_names,
             )
             for (model, wi) in self.weight_info.items()
         }
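The effect of the `_resolve_name` change is a fixed lookup order: the canonical name first, then `aliases`, then the new `tied_names`. A self-contained sketch of that order, with a plain set standing in for `loader.index.tensor_paths`:

```python
from typing import Iterable, Optional, Set

def resolve_name(
    tensor: str,
    aliases: Optional[Iterable[str]],
    tied_names: Optional[Iterable[str]],
    tensor_paths: Set[str],
) -> Optional[str]:
    # Canonical name first, then aliases, then the names of tied weights.
    for name in [tensor, *(aliases or []), *(tied_names or [])]:
        if name in tensor_paths:
            return name
    return None

# A checkpoint that ties lm_head to the input embedding stores only the latter:
paths = {"model.embed_tokens.weight"}
print(resolve_name("lm_head.weight", None, ("model.embed_tokens.weight",), paths))
# -> model.embed_tokens.weight
```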
2 changes: 1 addition & 1 deletion mergekit/io/tensor_writer.py
@@ -121,7 +121,7 @@ def finalize(self):
         json.dump(
             {
                 "metadata": {
-                    "mergekit_version": "0.0.5.1",
+                    "mergekit_version": "0.0.5.2",
                     "total_size": self.total_size,
                 },
                 "weight_map": self.weight_map,
16 changes: 16 additions & 0 deletions mergekit/merge_methods/__init__.py
@@ -93,6 +93,22 @@ def get(method: str) -> MergeMethod:
             default_normalize=False,
             default_rescale=True,
         )
+
+    elif method == "consensus_ta":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=None,
+            sparsification_method=SparsificationMethod.consensus_ta,
+            default_normalize=False,
+            default_rescale=False,
+        )
+
+    elif method == "consensus_ties":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=ConsensusMethod.sum,
+            sparsification_method=SparsificationMethod.consensus_ties,
+            default_normalize=True,
+            default_rescale=False,
+        )
     raise RuntimeError(f"Unimplemented merge method {method}")
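A quick sketch of what the two new branches mean for callers; this assumes nothing beyond the `get()` dispatcher shown above:

```python
from mergekit.merge_methods import get

# consensus_ta: task arithmetic deltas filtered by TALL-mask consensus,
# with no sign-election step (consensus_method=None).
ta = get("consensus_ta")

# consensus_ties: TIES-style sum consensus plus the same mask filtering,
# with normalization on by default.
ties = get("consensus_ties")

# Any other name still falls through to the RuntimeError.
```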
43 changes: 39 additions & 4 deletions mergekit/merge_methods/generalized_task_arithmetic.py
@@ -29,7 +29,7 @@
     MergeMethod,
     MergeTensorInput,
 )
-from mergekit.sparsify import SparsificationMethod, sparsify
+from mergekit.sparsify import SparsificationMethod, get_tall_mask, sparsify


 class ConsensusMethod(str, Enum):
@@ -79,6 +79,22 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
                     default_value=1.0,
                 )
             )
+        if (
+            self.sparsification_method == SparsificationMethod.consensus_ta
+            or self.sparsification_method == SparsificationMethod.consensus_ties
+        ):
+            res.append(
+                ConfigParameterDef(
+                    name="k",
+                    default_value=1,
+                )
+            )
+            res.append(
+                ConfigParameterDef(
+                    name="lambda",
+                    default_value=1.0,
+                )
+            )
         return res

     def make_task(
@@ -133,7 +149,10 @@ def execute(
             return base

         # sparsify
-        if self.method.sparsification_method:
+        if (
+            self.method.sparsification_method
+            and self.method.sparsification_method != SparsificationMethod.consensus_ta
+        ):
             for tv_info in tvs:
                 kwargs = {}
                 if "gamma" in tv_info:
@@ -142,15 +161,17 @@ def execute(
                 if "epsilon" in tv_info:
                     kwargs["epsilon"] = tv_info["epsilon"]

-                tv_info["delta"] = sparsify(
+                tv_info["sparsified_delta"] = sparsify(
                     tv_info["delta"],
                     density=tv_info["density"],
                     method=self.method.sparsification_method,
                     rescale=self.rescale,
                     **kwargs,
                 )

-        deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
+            deltas = torch.stack([tv["sparsified_delta"] for tv in tvs], dim=0)
+        else:
+            deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
         weights = torch.tensor(
             [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
         )
@@ -185,6 +206,20 @@ def execute(
             lambda_factor = tvs[0]["lambda"]
             mixed_delta *= lambda_factor

+        if (
+            self.method.sparsification_method == SparsificationMethod.consensus_ta
+            or self.method.sparsification_method == SparsificationMethod.consensus_ties
+        ):
+            for tv_info in tvs:
+                tv_info["tall_mask"] = get_tall_mask(
+                    tv_info["delta"],
+                    tv_info["lambda"],
+                    mixed_delta,
+                )
+            tall_masks = torch.stack([tv["tall_mask"] for tv in tvs], dim=0)
+            consensus_mask = tall_masks.sum(dim=0) >= tvs[0]["k"]
+            mixed_delta = mixed_delta * consensus_mask
+
         return (base + mixed_delta).to(base.dtype)

     def group_label(self) -> Optional[str]:
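The consensus step added to `execute` is compact enough to restate in isolation. A hedged sketch follows: `get_tall_mask` is not shown in this diff, so the masking rule below (keep parameters where a task's own delta outweighs its disagreement with the merged delta, scaled by lambda) is an assumption based on the TALL-masks formulation, not a quotation of `mergekit.sparsify`.

```python
from typing import List

import torch

def tall_mask(delta: torch.Tensor, lam: float, mixed: torch.Tensor) -> torch.Tensor:
    # Assumed TALL-mask rule; see the caveat in the lead-in above.
    return delta.abs() > lam * (mixed - delta).abs()

def consensus_filter(
    deltas: List[torch.Tensor], lam: float, k: int, mixed: torch.Tensor
) -> torch.Tensor:
    # Mirrors the commit's logic: stack per-model masks, keep entries that
    # at least k models agree on, and zero out the rest of the merged delta.
    masks = torch.stack([tall_mask(d, lam, mixed) for d in deltas], dim=0)
    return mixed * (masks.sum(dim=0) >= k)
```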
5 changes: 4 additions & 1 deletion mergekit/plan.py
@@ -139,7 +139,10 @@ def plan_tensor(
         any_weight = False
         for model, w_in in zip(models, weights_in):
             index = LoaderCache().get(model).index
-            if w_in.name in index.tensor_paths:
+            if any(
+                name in index.tensor_paths
+                for name in [w_in.name] + list(w_in.aliases or [])
+            ):
                 any_weight = True
                 break

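The presence check in isolation (a sketch; `index.tensor_paths` maps tensor names to shard files, so membership means the weight is stored under that name):

```python
from typing import Iterable, Mapping, Optional

def model_has_weight(
    name: str, aliases: Optional[Iterable[str]], tensor_paths: Mapping[str, str]
) -> bool:
    # A weight counts as present if it is stored under its canonical name
    # or any of its aliases.
    return any(n in tensor_paths for n in [name, *(aliases or [])])
```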
58 changes: 39 additions & 19 deletions mergekit/scripts/tokensurgeon.py
@@ -147,26 +147,42 @@ def main(
     )

     if lm_head_info:
-        old_lm_head = cache.get(model).get_tensor(
-            lm_head_info.name, aliases=lm_head_info.aliases, device=device
-        )
-        donor_lm_head = cache.get(donor).get_tensor(
-            donor_lm_head_info.name, aliases=donor_lm_head_info.aliases, device=device
-        )
+        try:
+            old_lm_head = cache.get(model).get_tensor(
+                lm_head_info.name, aliases=lm_head_info.aliases, device=device
+            )
+        except KeyError:
+            if lm_head_info.optional:
+                logging.info(f"LM head tensor {lm_head_info.name} not found, skipping")
+            else:
+                report_issue(
+                    f"Could not load LM head tensor {lm_head_info.name}",
+                    error=True,
+                )
+            old_lm_head = None

-        LOG.info("Computing new lm_head embeddings")
-        new_lm_head = get_embeddings(
-            old_lm_head,
-            donor_lm_head,
-            old_vocab,
-            new_vocab,
-            common_tokens,
-            accept_prefix=True,
-            k=k,
-            barycentric=barycentric,
-            cosine_similarity=cosine_similarity,
-            name=lm_head_info.name,
-        )
+        if old_lm_head is not None:
+            donor_lm_head = cache.get(donor).get_tensor(
+                donor_lm_head_info.name,
+                aliases=donor_lm_head_info.aliases,
+                device=device,
+            )
+
+            LOG.info("Computing new lm_head embeddings")
+            new_lm_head = get_embeddings(
+                old_lm_head,
+                donor_lm_head,
+                old_vocab,
+                new_vocab,
+                common_tokens,
+                accept_prefix=True,
+                k=k,
+                barycentric=barycentric,
+                cosine_similarity=cosine_similarity,
+                name=lm_head_info.name,
+            )
+        else:
+            new_lm_head = None

     # Save out the new model
     LOG.info(f"Saving new model to {out_path}")
@@ -184,6 +200,10 @@ def main(
             tensor = cache.get(model).get_tensor(
                 weight_info.name, aliases=weight_info.aliases
             )
+        if tensor is None:
+            if weight_info.optional:
+                continue
+            report_issue(f"Could not load weight tensor {weight_info.name}", error=True)
         writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)
     writer.finalize()
(Diffs for the remaining 2 of 21 changed files not shown.)