From ea44575060a6ae62865f59bae4341da0e83cb1d5 Mon Sep 17 00:00:00 2001
From: "Charles O. Goddard"
Date: Sat, 7 Dec 2024 18:12:24 -0800
Subject: [PATCH] NuSLERP (#357)

Adds a new merge method, `nuslerp`, which provides a superset of the
functionality of `slerp`. If provided with a base model, `nuslerp` will
perform spherical interpolation of the task vectors. And while the original
`slerp` always flattens weight tensors into a single dimension, `nuslerp`
can also do row-wise and column-wise interpolation of tensors.

This method remedies one of my long-standing gripes with how I implemented
`slerp`. Instead of taking a `t` parameter and using `base_model` to specify
which model is the "first" one, `nuslerp` simply takes a `weight` parameter
for each model and computes the interpolation factor `t` internally. This
fits the conventions of the other merge methods much better. The `weight`
parameter behaves in the same fashion as it does for `merge_method: linear`
with `normalize: true`.

The idea to add task-vector SLERP is inspired by DeepMind's great use of it
in their [WARP](https://arxiv.org/abs/2406.16768) paper.
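For illustration, a config exercising the task-vector path might look like
this (model names and dtype here are placeholders, not part of the patch):

```yaml
merge_method: nuslerp
base_model: org/base-model   # task vectors are computed against this model
models:
  - model: org/model-a
    parameters:
      weight: 0.5
  - model: org/model-b
    parameters:
      weight: 0.5
parameters:
  nuslerp_flatten: false     # row-/column-wise SLERP instead of flattening
dtype: bfloat16
```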
---
 README.md                          |  17 ++-
 mergekit/card.py                   |   3 +
 mergekit/merge_methods/__init__.py |   3 +
 mergekit/merge_methods/nuslerp.py  | 168 +++++++++++++++++++++++++++++
 tests/test_basic_merges.py         |  32 ++++++
 5 files changed, 221 insertions(+), 2 deletions(-)
 create mode 100644 mergekit/merge_methods/nuslerp.py

diff --git a/README.md b/README.md
index 225bd838..c5e166c6 100644
--- a/README.md
+++ b/README.md
@@ -250,8 +250,9 @@ A quick overview of the currently supported merge methods:
 | [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) | `breadcrumbs` | ✅ | ✅ |
 | [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ |
 | [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
-| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
-| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
+| NuSLERP | `nuslerp` | ❌ | ✅ |
+| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
+| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
 
 ### Linear
 
@@ -313,6 +314,18 @@ Parameters:
 
 - `filter_wise`: if true, weight calculation will be per-row rather than per-tensor. Not recommended.
 
+### NuSLERP
+
+Spherically interpolate between parameters, but with more options and a more sensible configuration! Does not require a base model, but can use one to do spherical interpolation of task vectors. Works only with exactly two models, or two models plus a base model.
+
+Parameters:
+
+- `weight`: relative weighting of a given tensor
+- `nuslerp_flatten`: set to false to do row-wise/column-wise interpolation instead of treating tensors as vectors
+- `nuslerp_row_wise`: SLERP row vectors instead of column vectors
+
+To replicate the behavior of the original `slerp` method, set `weight` to `1 - t` and `t` for your first and second model, respectively.
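+
+For example, the following configuration (model names are placeholders) reproduces the original `slerp` with `t: 0.4`:
+
+```yaml
+models:
+  - model: org/first-model
+    parameters:
+      weight: 0.6
+  - model: org/second-model
+    parameters:
+      weight: 0.4
+merge_method: nuslerp
+```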
+
 ### [DELLA](https://arxiv.org/abs/2406.11617)
 
 Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similarly to [DARE](#dare).
 DELLA can be used with (`della`) or without (`della_linear`) the sign-election step of TIES.
diff --git a/mergekit/card.py b/mergekit/card.py
index bf0a2d0a..5adb30aa 100644
--- a/mergekit/card.py
+++ b/mergekit/card.py
@@ -118,6 +118,9 @@ def method_md(merge_method: str) -> str:
         "dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)",
         "dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)",
         "model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)",
+        "della": "[DELLA](https://arxiv.org/abs/2406.11617)",
+        "della_linear": "linear [DELLA](https://arxiv.org/abs/2406.11617)",
+        "nuslerp": "NuSLERP",
     }
     return methods.get(merge_method, merge_method)
diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py
index 6dc92023..2fd02a08 100644
--- a/mergekit/merge_methods/__init__.py
+++ b/mergekit/merge_methods/__init__.py
@@ -21,6 +21,7 @@ )
 from mergekit.merge_methods.linear import LinearMerge
 from mergekit.merge_methods.model_stock import ModelStockMerge
+from mergekit.merge_methods.nuslerp import NuSlerpMerge
 from mergekit.merge_methods.passthrough import PassthroughMerge
 from mergekit.merge_methods.slerp import SlerpMerge
 from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
@@ -31,6 +32,8 @@ def get(method: str) -> MergeMethod:
         return LinearMerge()
     elif method == "slerp":
         return SlerpMerge()
+    elif method == "nuslerp":
+        return NuSlerpMerge()
     elif method == "passthrough":
         return PassthroughMerge()
     elif method == "task_arithmetic":
diff --git a/mergekit/merge_methods/nuslerp.py b/mergekit/merge_methods/nuslerp.py
new file mode 100644
index 00000000..448e669e
--- /dev/null
+++ b/mergekit/merge_methods/nuslerp.py
@@ -0,0 +1,168 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch import Tensor
+
+from mergekit.architecture import WeightInfo
+from mergekit.common import ImmutableMap, ModelReference
+from mergekit.graph import Task
+from mergekit.io.tasks import GatherTensors
+from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
+from mergekit.merge_methods.rectify_embed import rectify_embed_sizes
+
+
+class NuSlerpTask(Task[torch.Tensor]):
+    gather_tensors: GatherTensors
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
+    weight_info: WeightInfo
+    row_wise: bool
+    flatten: bool
+    base_model: Optional[ModelReference]
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors": self.gather_tensors}
+
+    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
+        if len(tensors) == 1:
+            return list(tensors.values())[0]
+
+        if self.base_model is not None:
+            if len(tensors) != 3:
+                raise RuntimeError(
+                    "NuSlerp base model cannot be one of the two models to merge"
+                )
+            base_tensor = tensors.pop(self.base_model)
+        else:
+            base_tensor = None
+
+        keys = list(tensors.keys())
+        tensors = [tensors[key] for key in keys]
+        weights = [self.tensor_parameters[key]["weight"] for key in keys]
+
+        if len(tensors) != 2:
+            raise RuntimeError(
+                "NuSlerp merge expects exactly two models (plus an optional base "
+                f"model), got: {keys}"
+            )
+
+        if abs(sum(weights)) < 1e-6:
+            # this is fairly arbitrary, but it's more sane than exploding
+            t = 0.5
+        else:
+            t = weights[1] / sum(weights)
+
+        if base_tensor is not None:
+            tensors.append(base_tensor)
+        rectify_embed_sizes(self.weight_info, tensors)
+
+        if base_tensor is not None:
+            base_tensor = tensors.pop()
+            # SLERP the task vectors (deltas from the base model)
+            return base_tensor + nuslerp(
+                t,
+                tensors[0] - base_tensor,
+                tensors[1] - base_tensor,
+                dim=0 if self.row_wise else -1,
+                flatten=self.flatten,
+            )
+        return nuslerp(
+            t,
+            tensors[0],
+            tensors[1],
+            dim=0 if self.row_wise else -1,
+            flatten=self.flatten,
+        )
+
+
+class NuSlerpMerge(MergeMethod):
+    def parameters(self) -> List[ConfigParameterDef]:
+        return [
+            ConfigParameterDef(
+                name="nuslerp_row_wise",
+                required=False,
+                default_value=False,
+            ),
+            ConfigParameterDef(
+                name="nuslerp_flatten",
+                required=False,
+                default_value=True,
+            ),
+        ]
+
+    def tensor_parameters(self) -> List[ConfigParameterDef]:
+        return [ConfigParameterDef(name="weight", required=True)]
+
+    def make_task(
+        self,
+        *,
+        output_weight: WeightInfo,
+        tensors: GatherTensors,
+        base_model: Optional[ModelReference],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
+        **_kwargs,
+    ) -> Task:
+        return NuSlerpTask(
+            gather_tensors=tensors,
+            tensor_parameters=tensor_parameters,
+            weight_info=output_weight,
+            row_wise=parameters["nuslerp_row_wise"],
+            flatten=parameters["nuslerp_flatten"],
+            base_model=base_model,
+        )
+
+
+def nuslerp(
+    t: float,
+    v0: torch.Tensor,
+    v1: torch.Tensor,
+    dim: int = -1,
+    eps: float = 1e-8,
+    flatten: bool = False,
+) -> torch.Tensor:
+    out_shape = v0.shape
+
+    def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+        return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps)
+
+    if flatten:
+        v0 = v0.view(-1)
+        v1 = v1.view(-1)
+    elif dim != -1:
+        v0 = v0.transpose(dim, -1)
+        v1 = v1.transpose(dim, -1)
+
+    v0_u = _normalize(v0)
+    v1_u = _normalize(v1)
+
+    cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True)
+    theta = torch.acos(cos_theta.clamp(-1, 1))
+    sin_theta = torch.sin(theta)
+
+    colinear = (sin_theta.abs() < eps).squeeze()
+
+    res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta
+    # Use linear interpolation for (nearly) colinear vectors
+    res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear]
+
+    if dim != -1 and not flatten:
+        res = res.transpose(dim, -1)
+    return res.view(out_shape)
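As a quick sanity check of the `nuslerp` helper above (an illustrative sketch, assuming `mergekit` and `torch` are importable):

```python
import torch

from mergekit.merge_methods.nuslerp import nuslerp

v0 = torch.randn(4, 8)
v1 = torch.randn(4, 8)

# t=0 and t=1 recover the endpoints (up to floating-point error)
assert torch.allclose(nuslerp(0.0, v0, v1), v0, atol=1e-5)
assert torch.allclose(nuslerp(1.0, v0, v1), v1, atol=1e-5)

# Intermediate t stays on the arc between the inputs; the original
# tensor shape is restored even when flattening.
mid = nuslerp(0.5, v0, v1, flatten=True)
print(mid.shape)  # torch.Size([4, 8])
```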
diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py
index ae54de43..84a673aa 100644
--- a/tests/test_basic_merges.py
+++ b/tests/test_basic_merges.py
@@ -99,6 +99,38 @@ def test_slerp_merge(self, model_a, model_b):
         config.parameters = {"t": 0.35}
         run_and_check_merge(config)
 
+    def test_nuslerp_merges(self, model_a, model_b, model_c):
+        for base_model in [None, model_c]:
+            for row_wise in [False, True]:
+                for flatten in [False, True]:
+                    print(
+                        f"Testing nuslerp with row_wise={row_wise}, flatten={flatten}, base_model={base_model}"
+                    )
+                    run_and_check_merge(
+                        self.two_model_config(
+                            model_a,
+                            model_b,
+                            merge_method="nuslerp",
+                            base_model=base_model,
+                            params={
+                                "nuslerp_row_wise": row_wise,
+                                "nuslerp_flatten": flatten,
+                            },
+                        )
+                    )
+
+        # test weights that sum to zero
+        config = self.two_model_config(
+            model_a,
+            model_b,
+            merge_method="nuslerp",
+            base_model=model_c,
+            params={"nuslerp_row_wise": False, "nuslerp_flatten": False},
+        )
+        config.models[0].parameters["weight"] = -0.5
+        config.models[1].parameters["weight"] = 0.5
+        run_and_check_merge(config)
+
     def test_task_arithmetic_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
             model_a, model_b, merge_method="task_arithmetic", base_model=model_c