From 68c4b65e93d9e5c16d04935e8d567cb74cf4886a Mon Sep 17 00:00:00 2001 From: "Charles O. Goddard" Date: Sat, 30 Nov 2024 13:55:02 -0800 Subject: [PATCH 1/2] Better tied weight handling (#464) Handle cases where some input models have a tied tensor and some don't. For example, there are some fine tunes of Llama 3.2 3B floating around that are ~3.6B parameters because they have a separate LM head - with these changes these can be merged with standard sized ones. There will be a LM head in the output model if any inputs have one. Otherwise behavior will be as it was before. --- .../_data/architectures/bert-masked-lm.json | 3 +- .../architectures/distilbert-masked-lm.json | 3 +- mergekit/_data/architectures/gemma2.json | 5 +- mergekit/_data/architectures/gptbigcode.json | 4 +- mergekit/_data/architectures/internlm2.json | 3 +- mergekit/_data/architectures/llama.json | 5 +- mergekit/_data/architectures/mamba.json | 5 +- mergekit/_data/architectures/phi3-small.json | 5 +- mergekit/_data/architectures/qwen2.json | 3 +- .../architectures/roberta-masked-lm.json | 7 ++- mergekit/_data/architectures/solar.json | 3 +- mergekit/_data/architectures/starcoder2.json | 5 +- mergekit/architecture.py | 4 ++ mergekit/io/tasks.py | 6 +- mergekit/io/tensor_writer.py | 2 +- mergekit/plan.py | 5 +- mergekit/scripts/tokensurgeon.py | 58 +++++++++++++------ pyproject.toml | 2 +- 18 files changed, 91 insertions(+), 37 deletions(-) diff --git a/mergekit/_data/architectures/bert-masked-lm.json b/mergekit/_data/architectures/bert-masked-lm.json index 3b0620fb..d6430e40 100644 --- a/mergekit/_data/architectures/bert-masked-lm.json +++ b/mergekit/_data/architectures/bert-masked-lm.json @@ -44,7 +44,8 @@ }, { "name": "cls.predictions.decoder.weight", - "aliases": [ + "optional": true, + "tied_names": [ "bert.embeddings.word_embeddings.weight" ], "is_embed": true diff --git a/mergekit/_data/architectures/distilbert-masked-lm.json b/mergekit/_data/architectures/distilbert-masked-lm.json index 6828cca2..1a079811 100644 --- a/mergekit/_data/architectures/distilbert-masked-lm.json +++ b/mergekit/_data/architectures/distilbert-masked-lm.json @@ -40,7 +40,8 @@ { "name": "vocab_projector.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "distilbert.embeddings.word_embeddings.weight" ] }, diff --git a/mergekit/_data/architectures/gemma2.json b/mergekit/_data/architectures/gemma2.json index 0c6372f0..52505245 100644 --- a/mergekit/_data/architectures/gemma2.json +++ b/mergekit/_data/architectures/gemma2.json @@ -54,7 +54,10 @@ { "name": "lm_head.weight", "is_embed": true, - "optional": true + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] } ] } diff --git a/mergekit/_data/architectures/gptbigcode.json b/mergekit/_data/architectures/gptbigcode.json index 4b086278..c12bac5c 100644 --- a/mergekit/_data/architectures/gptbigcode.json +++ b/mergekit/_data/architectures/gptbigcode.json @@ -21,7 +21,9 @@ }, { "name": "lm_head.weight", - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "transformer.wte.weight" ] } diff --git a/mergekit/_data/architectures/internlm2.json b/mergekit/_data/architectures/internlm2.json index 057bc649..888faa48 100644 --- a/mergekit/_data/architectures/internlm2.json +++ b/mergekit/_data/architectures/internlm2.json @@ -16,7 +16,8 @@ { "name": "output.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.tok_embeddings.weight" ] } diff --git a/mergekit/_data/architectures/llama.json 
b/mergekit/_data/architectures/llama.json index 7106806b..00918a2c 100644 --- a/mergekit/_data/architectures/llama.json +++ b/mergekit/_data/architectures/llama.json @@ -74,7 +74,10 @@ "name": "lm_head.weight", "input_space": "running_residual", "is_embed": true, - "optional": true + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] } ] } diff --git a/mergekit/_data/architectures/mamba.json b/mergekit/_data/architectures/mamba.json index b3727dba..1c473532 100644 --- a/mergekit/_data/architectures/mamba.json +++ b/mergekit/_data/architectures/mamba.json @@ -16,7 +16,10 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": ["backbone.embeddings.weight"] + "optional": true, + "tied_names": [ + "backbone.embeddings.weight" + ] } ], "num_layers_config_key": "num_hidden_layers", diff --git a/mergekit/_data/architectures/phi3-small.json b/mergekit/_data/architectures/phi3-small.json index 7b3a1e80..f27dfac4 100644 --- a/mergekit/_data/architectures/phi3-small.json +++ b/mergekit/_data/architectures/phi3-small.json @@ -12,8 +12,9 @@ "post_weights": [ { "name": "lm_head.weight", - "is_embed":true, - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "model.embed_tokens.weight" ] }, diff --git a/mergekit/_data/architectures/qwen2.json b/mergekit/_data/architectures/qwen2.json index 638b3630..c7131523 100644 --- a/mergekit/_data/architectures/qwen2.json +++ b/mergekit/_data/architectures/qwen2.json @@ -16,7 +16,8 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.embed_tokens.weight" ] } diff --git a/mergekit/_data/architectures/roberta-masked-lm.json b/mergekit/_data/architectures/roberta-masked-lm.json index 492127a5..1aae76a1 100644 --- a/mergekit/_data/architectures/roberta-masked-lm.json +++ b/mergekit/_data/architectures/roberta-masked-lm.json @@ -8,7 +8,8 @@ "name": "roberta.embeddings.position_embeddings.weight" }, { - "name": "roberta.embeddings.word_embeddings.weight" + "name": "roberta.embeddings.word_embeddings.weight", + "is_embed": true }, { "name": "roberta.embeddings.token_type_embeddings.weight" @@ -43,7 +44,9 @@ }, { "name": "lm_head.decoder.weight", - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "roberta.embeddings.word_embeddings.weight" ] } diff --git a/mergekit/_data/architectures/solar.json b/mergekit/_data/architectures/solar.json index 7bd6a751..78fd5998 100644 --- a/mergekit/_data/architectures/solar.json +++ b/mergekit/_data/architectures/solar.json @@ -73,7 +73,8 @@ "name": "lm_head.weight", "input_space": "running_residual", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.lm_head.weight" ] } diff --git a/mergekit/_data/architectures/starcoder2.json b/mergekit/_data/architectures/starcoder2.json index 851fdd1a..c2266899 100644 --- a/mergekit/_data/architectures/starcoder2.json +++ b/mergekit/_data/architectures/starcoder2.json @@ -13,7 +13,10 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": ["model.embed_tokens.weight"] + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] }, { "name": "model.norm.bias" diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 4c7b4625..40872160 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True): Indicates whether the weight can be omitted from a model. aliases (Optional[List[str]]): List of alternative names for the weight, if applicable. 
+ tied_names (Optional[List[str]]): + List of names for weights that are tied to this weight, if applicable. force_dtype (Optional[str]): Mandatory dtype for the weight, if applicable. """ @@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True): input_space: Optional[str] = None output_space: Optional[str] = None optional: bool = False + tied: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None head_split: Literal[None, "input", "output"] = None is_kq: Optional[bool] = False diff --git a/mergekit/io/tasks.py b/mergekit/io/tasks.py index 70dffc41..499ad4c0 100644 --- a/mergekit/io/tasks.py +++ b/mergekit/io/tasks.py @@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]): device: Optional[str] = None optional: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None def arguments(self) -> Dict[str, Task]: return {} def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]: - all_names = [self.tensor] + list(self.aliases or []) + all_names = ( + [self.tensor] + list(self.aliases or []) + list(self.tied_names or []) + ) for name in all_names: if name in loader.index.tensor_paths: return name @@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]: device=self.device, optional=wi.optional, aliases=wi.aliases, + tied_names=wi.tied_names, ) for (model, wi) in self.weight_info.items() } diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index 199772ea..9ea58222 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -121,7 +121,7 @@ def finalize(self): json.dump( { "metadata": { - "mergekit_version": "0.0.5.1", + "mergekit_version": "0.0.5.2", "total_size": self.total_size, }, "weight_map": self.weight_map, diff --git a/mergekit/plan.py b/mergekit/plan.py index bdcd7004..5b34eddc 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -139,7 +139,10 @@ def plan_tensor( any_weight = False for model, w_in in zip(models, weights_in): index = LoaderCache().get(model).index - if w_in.name in index.tensor_paths: + if any( + name in index.tensor_paths + for name in [w_in.name] + (w_in.aliases or []) + ): any_weight = True break diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py index a6715643..ea6dd4bc 100644 --- a/mergekit/scripts/tokensurgeon.py +++ b/mergekit/scripts/tokensurgeon.py @@ -147,26 +147,42 @@ def main( ) if lm_head_info: - old_lm_head = cache.get(model).get_tensor( - lm_head_info.name, aliases=lm_head_info.aliases, device=device - ) - donor_lm_head = cache.get(donor).get_tensor( - donor_lm_head_info.name, aliases=donor_lm_head_info.aliases, device=device - ) + try: + old_lm_head = cache.get(model).get_tensor( + lm_head_info.name, aliases=lm_head_info.aliases, device=device + ) + except KeyError: + if lm_head_info.optional: + logging.info(f"LM head tensor {lm_head_info.name} not found, skipping") + else: + report_issue( + f"Could not load LM head tensor {lm_head_info.name}", + error=True, + ) + old_lm_head = None - LOG.info("Computing new lm_head embeddings") - new_lm_head = get_embeddings( - old_lm_head, - donor_lm_head, - old_vocab, - new_vocab, - common_tokens, - accept_prefix=True, - k=k, - barycentric=barycentric, - cosine_similarity=cosine_similarity, - name=lm_head_info.name, - ) + if old_lm_head is not None: + donor_lm_head = cache.get(donor).get_tensor( + donor_lm_head_info.name, + aliases=donor_lm_head_info.aliases, + device=device, + ) + + 
LOG.info("Computing new lm_head embeddings") + new_lm_head = get_embeddings( + old_lm_head, + donor_lm_head, + old_vocab, + new_vocab, + common_tokens, + accept_prefix=True, + k=k, + barycentric=barycentric, + cosine_similarity=cosine_similarity, + name=lm_head_info.name, + ) + else: + new_lm_head = None # Save out the new model LOG.info(f"Saving new model to {out_path}") @@ -184,6 +200,10 @@ def main( tensor = cache.get(model).get_tensor( weight_info.name, aliases=weight_info.aliases ) + if tensor is None: + if weight_info.optional: + continue + report_issue(f"Could not load weight tensor {weight_info.name}", error=True) writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors) writer.finalize() diff --git a/pyproject.toml b/pyproject.toml index 128a5b87..e04fd464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "mergekit" description = "Tools for merging pre-trained large language models" readme = "README.md" license = { text = "LGPL-3.0-or-later" } -version = "0.0.5.1" +version = "0.0.5.2" authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }] dependencies = [ "torch>=2.0.0", From 8d1a10df83bb9ac9c3854199aed20542c7d08231 Mon Sep 17 00:00:00 2001 From: zsgvivo Date: Sun, 1 Dec 2024 06:18:40 +0800 Subject: [PATCH 2/2] Add methods from https://arxiv.org/abs/2405.07813 (#441) add consensus_ties and consensus_ta method from https://arxiv.org/abs/2405.07813 --- mergekit/merge_methods/__init__.py | 16 +++++++ .../generalized_task_arithmetic.py | 43 +++++++++++++++++-- mergekit/sparsify.py | 16 ++++++- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 007e163e..6dc92023 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -93,6 +93,22 @@ def get(method: str) -> MergeMethod: default_normalize=False, default_rescale=True, ) + + elif method == "consensus_ta": + return GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.consensus_ta, + default_normalize=False, + default_rescale=False, + ) + + elif method == "consensus_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.consensus_ties, + default_normalize=True, + default_rescale=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 214726b7..0bb3f0c7 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -29,7 +29,7 @@ MergeMethod, MergeTensorInput, ) -from mergekit.sparsify import SparsificationMethod, sparsify +from mergekit.sparsify import SparsificationMethod, get_tall_mask, sparsify class ConsensusMethod(str, Enum): @@ -79,6 +79,22 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: default_value=1.0, ) ) + if ( + self.sparsification_method == SparsificationMethod.consensus_ta + or self.sparsification_method == SparsificationMethod.consensus_ties + ): + res.append( + ConfigParameterDef( + name="k", + default_value=1, + ) + ) + res.append( + ConfigParameterDef( + name="lambda", + default_value=1.0, + ) + ) return res def make_task( @@ -133,7 +149,10 @@ def execute( return base # sparsify - if self.method.sparsification_method: + if ( + self.method.sparsification_method + and 
self.method.sparsification_method != SparsificationMethod.consensus_ta + ): for tv_info in tvs: kwargs = {} if "gamma" in tv_info: @@ -142,7 +161,7 @@ def execute( if "epsilon" in tv_info: kwargs["epsilon"] = tv_info["epsilon"] - tv_info["delta"] = sparsify( + tv_info["sparsified_delta"] = sparsify( tv_info["delta"], density=tv_info["density"], method=self.method.sparsification_method, @@ -150,7 +169,9 @@ def execute( **kwargs, ) - deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) + deltas = torch.stack([tv["sparsified_delta"] for tv in tvs], dim=0) + else: + deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) weights = torch.tensor( [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device ) @@ -185,6 +206,20 @@ def execute( lambda_factor = tvs[0]["lambda"] mixed_delta *= lambda_factor + if ( + self.method.sparsification_method == SparsificationMethod.consensus_ta + or self.method.sparsification_method == SparsificationMethod.consensus_ties + ): + for tv_info in tvs: + tv_info["tall_mask"] = get_tall_mask( + tv_info["delta"], + tv_info["lambda"], + mixed_delta, + ) + tall_masks = torch.stack([tv["tall_mask"] for tv in tvs], dim=0) + consensus_mask = tall_masks.sum(dim=0) >= tvs[0]["k"] + mixed_delta = mixed_delta * consensus_mask + return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index ee6477c3..f782247f 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,6 +23,8 @@ class SparsificationMethod(str, Enum): random = "random" magnitude_outliers = "magnitude_outliers" rank_magnitude_sampling = "rank_magnitude_sampling" + consensus_ta = "consensus_ta" + consensus_ties = "consensus_ties" def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): @@ -177,7 +179,10 @@ def sparsify( rescale: bool = False, epsilon: float = 0.15, ) -> torch.Tensor: - if method == SparsificationMethod.magnitude: + if ( + method == SparsificationMethod.magnitude + or method == SparsificationMethod.consensus_ties + ): return magnitude(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.random: return bernoulli(tensor, density=density, rescale=rescale) @@ -187,3 +192,12 @@ def sparsify( return rank_magnitude(tensor, density=density, rescale=rescale, epsilon=epsilon) else: raise NotImplementedError(method) + + +def get_tall_mask( + delta: torch.Tensor, # individual task vectors + lambda_factor: float, # hyper-parameter lambda for generating TALL masks + mixed_delta: torch.Tensor, # multi-task vector +): + mask = delta.abs() > lambda_factor * (mixed_delta - delta).abs() + return mask
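
The consensus step added in the second commit follows the TALL-mask construction from https://arxiv.org/abs/2405.07813: each model's task vector claims the parameters it "owns", and only parameters claimed by at least k models survive in the merged delta. Below is a minimal standalone sketch of that masking logic, mirroring get_tall_mask and the consensus filter in execute(); the tensor shapes and hyper-parameter values are illustrative only, and the deltas stand in for the per-model task vectors mergekit computes internally.

import torch

# Illustrative task vectors (fine-tuned weights minus base weights) for two
# models; the shapes and values are made up for demonstration purposes.
delta_a = torch.randn(4, 8)
delta_b = torch.randn(4, 8)
deltas = torch.stack([delta_a, delta_b], dim=0)

# Multi-task vector: the patch builds this as the weighted (and optionally
# normalized) sum of the per-model deltas; a plain mean keeps the sketch short.
mixed_delta = deltas.mean(dim=0)

lambda_factor = 0.4  # per-model "lambda" hyper-parameter
k = 2                # "k": how many models must claim a parameter for it to survive

# TALL mask (get_tall_mask): a parameter is claimed by a task when its own delta
# dominates the remainder of the multi-task vector.
tall_masks = torch.stack(
    [delta.abs() > lambda_factor * (mixed_delta - delta).abs() for delta in deltas],
    dim=0,
)

# Consensus mask: keep only parameters claimed by at least k models, then apply
# it to the merged delta before it is added back onto the base weights.
consensus_mask = tall_masks.sum(dim=0) >= k
mixed_delta = mixed_delta * consensus_mask

With consensus_ties, the per-model deltas are additionally magnitude-pruned and sign-elected (the existing TIES path) before the consensus mask is applied; with consensus_ta they are left dense, matching the two branches registered in merge_methods/__init__.py.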