NuSLERP (#357)
Adds a new merge method, `nuslerp`. This method provides a superset of
the functionality of `slerp`. If provided with a base model, `nuslerp`
will perform spherical interpolation of the task vectors. While the
original `slerp` always flattens weight tensors into a single dimension,
`nuslerp` can also do row-wise and column-wise interpolation of tensors.
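In effect, with a base model the merged tensor is `base + slerp(t, a - base, b - base)` for model tensors `a` and `b`; without one it is simply `slerp(t, a, b)`.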

This method remedies one of my long-standing gripes with how I
implemented `slerp`. Instead of taking a `t` parameter and using
`base_model` to specify which is the "first" model, `nuslerp` simply
takes a `weight` parameter for each model and computes the interpolation
factor `t` internally. This makes it fit the conventions of the other
merge methods much better. The `weight` parameter behaves in the same
fashion as it does for `merge_method: linear` with `normalize: true`.
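A minimal sketch of that weight-to-`t` mapping, mirroring the logic in `NuSlerpTask.execute` below (the helper name is mine):

```python
def weights_to_t(w0: float, w1: float) -> float:
    """Normalize two per-model weights into a single interpolation factor."""
    total = w0 + w1
    if abs(total) < 1e-6:
        # weights cancel out; fall back to the midpoint rather than exploding
        return 0.5
    return w1 / total

# weights of 1 and 3 normalize to t = 0.75, just as `linear` would weight them
assert weights_to_t(1.0, 3.0) == 0.75
```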

The idea to add task vector SLERP is inspired by DeepMind's great use of
it in their [WARP](https://arxiv.org/abs/2406.16768) paper.
cg123 authored Dec 8, 2024
1 parent 00f8bf4 commit ea44575
Showing 5 changed files with 221 additions and 2 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -250,8 +250,9 @@ A quick overview of the currently supported merge methods:
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) | `breadcrumbs` | ✅ | ✅ |
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ |
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
| NuSLERP | `nuslerp` | ❌ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |

### Linear

@@ -313,6 +314,18 @@ Parameters:

- `filter_wise`: if true, weight calculation will be per-row rather than per-tensor. Not recommended.

### NuSLERP

Spherically interpolate between parameters, but with more options and more sensible configuration! Does not require a base model, but can use one to perform spherical interpolation of task vectors. Works only with exactly two models, plus an optional base model.

Parameters:

- `weight`: relative weighting of a given tensor
- `nuslerp_flatten`: set to false to do row-wise/column-wise interpolation instead of treating tensors as vectors
- `nuslerp_row_wise`: SLERP row vectors instead of column vectors

To replicate the behavior of the original `slerp` method, set `weight` to `1-t` and `t` for your first and second model, respectively.
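For a concrete picture of the two tensor-handling modes, here is a sketch that calls the `nuslerp` helper added in this PR directly (shapes and values are illustrative):

```python
import torch
from mergekit.merge_methods.nuslerp import nuslerp

a, b = torch.randn(4, 8), torch.randn(4, 8)

# nuslerp_flatten: true (the default) treats each tensor as one flat vector
flat = nuslerp(0.35, a, b, flatten=True)

# nuslerp_flatten: false interpolates vector-by-vector along one tensor axis;
# the axis is selected by nuslerp_row_wise (dim=0 here, per NuSlerpTask)
per_vector = nuslerp(0.35, a, b, dim=0, flatten=False)
```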

### [DELLA](https://arxiv.org/abs/2406.11617)

Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of the delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similarly to [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign-elect step of TIES.
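A rough sketch of the adaptive-pruning idea for a single row (my paraphrase of the paper's scheme, not mergekit's implementation; `density`, `eps`, and the function name are illustrative):

```python
import torch

def della_prune_row(delta: torch.Tensor, density: float = 0.5, eps: float = 0.1) -> torch.Tensor:
    """Sketch of magnitude-adaptive pruning for one row of delta parameters."""
    n = delta.numel()
    # rank 0 = smallest magnitude, n - 1 = largest
    ranks = delta.abs().argsort().argsort().float()
    # keep probability grows linearly with magnitude rank, centered on `density`,
    # so small-magnitude deltas are dropped more often than large ones
    keep_prob = (density - eps / 2) + eps * ranks / max(n - 1, 1)
    keep_prob = keep_prob.clamp(min=1e-3, max=1.0)
    mask = torch.bernoulli(keep_prob)
    # rescale survivors to keep the expected value of each parameter unchanged
    return delta * mask / keep_prob
```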
3 changes: 3 additions & 0 deletions mergekit/card.py
@@ -118,6 +118,9 @@ def method_md(merge_method: str) -> str:
"dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)",
"dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)",
"model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)",
"della": "[DELLA](https://arxiv.org/abs/2406.11617)",
"della_linear": "linear [DELLA](https://arxiv.org/abs/2406.11617)",
"nuslerp": "NuSLERP",
}
return methods.get(merge_method, merge_method)

3 changes: 3 additions & 0 deletions mergekit/merge_methods/__init__.py
@@ -21,6 +21,7 @@
)
from mergekit.merge_methods.linear import LinearMerge
from mergekit.merge_methods.model_stock import ModelStockMerge
from mergekit.merge_methods.nuslerp import NuSlerpMerge
from mergekit.merge_methods.passthrough import PassthroughMerge
from mergekit.merge_methods.slerp import SlerpMerge
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
@@ -31,6 +32,8 @@ def get(method: str) -> MergeMethod:
        return LinearMerge()
    elif method == "slerp":
        return SlerpMerge()
    elif method == "nuslerp":
        return NuSlerpMerge()
    elif method == "passthrough":
        return PassthroughMerge()
    elif method == "task_arithmetic":
168 changes: 168 additions & 0 deletions mergekit/merge_methods/nuslerp.py
@@ -0,0 +1,168 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional

import torch
from torch import Tensor

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class NuSlerpTask(Task[torch.Tensor]):
    gather_tensors: GatherTensors
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    weight_info: WeightInfo
    row_wise: bool
    flatten: bool
    base_model: Optional[ModelReference]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
        if len(tensors) == 1:
            return list(tensors.values())[0]

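        # when a base model is given, the other two tensors are treated as
        # offsets from it and the SLERP happens on those task vectors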
        if self.base_model is not None:
            if len(tensors) != 3:
                raise RuntimeError(
                    "NuSlerp base model cannot be one of the two models to merge"
                )
            base_tensor = tensors.pop(self.base_model)
        else:
            base_tensor = None

        keys = list(tensors.keys())
        tensors = [tensors[key] for key in keys]
        weights = [self.tensor_parameters[key]["weight"] for key in keys]

        if len(tensors) != 2:
            raise RuntimeError(
                "NuSlerp merge expects exactly two models (plus optional base "
                f"model), got {len(tensors)}: {keys} (base_model={self.base_model})"
            )

        if abs(sum(weights)) < 1e-6:
            # this is fairly arbitrary, but it's more sane than exploding
            t = 0.5
        else:
            t = weights[1] / sum(weights)

        if base_tensor is not None:
            tensors.append(base_tensor)
        rectify_embed_sizes(self.weight_info, tensors)

        if base_tensor is not None:
            base_tensor = tensors.pop()
            return base_tensor + nuslerp(
                t,
                tensors[0] - base_tensor,
                tensors[1] - base_tensor,
                dim=0 if self.row_wise else -1,
                flatten=self.flatten,
            )
        return nuslerp(
            t,
            tensors[0],
            tensors[1],
            dim=0 if self.row_wise else -1,
            flatten=self.flatten,
        )


class NuSlerpMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(
                name="nuslerp_row_wise",
                required=False,
                default_value=False,
            ),
            ConfigParameterDef(
                name="nuslerp_flatten",
                required=False,
                default_value=True,
            ),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [ConfigParameterDef(name="weight", required=True)]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        return NuSlerpTask(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            weight_info=output_weight,
            row_wise=parameters["nuslerp_row_wise"],
            flatten=parameters["nuslerp_flatten"],
            base_model=base_model,
        )


def nuslerp(
    t: float,
    v0: torch.Tensor,
    v1: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-8,
    flatten: bool = False,
):
    out_shape = v0.shape

    def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps)

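    # either flatten each tensor into one long vector, or move the requested
    # axis to the end so the SLERP below acts on per-row/per-column vectors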
    if flatten:
        v0 = v0.view(-1)
        v1 = v1.view(-1)
    elif dim != -1:
        v0 = v0.transpose(dim, -1)
        v1 = v1.transpose(dim, -1)

    v0_u = _normalize(v0)
    v1_u = _normalize(v1)

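    # angle between each pair of unit vectors, clamped for numerical safety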
    cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True)
    theta = torch.acos(cos_theta.clamp(-1, 1))
    sin_theta = torch.sin(theta)

    colinear = (sin_theta.abs() < eps).squeeze()

    res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta
    # Use linear interpolation for (nearly) colinear vectors
    res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear]

    if dim != -1 and not flatten:
        res = res.transpose(dim, -1)
    return res.view(out_shape)
32 changes: 32 additions & 0 deletions tests/test_basic_merges.py
@@ -99,6 +99,38 @@ def test_slerp_merge(self, model_a, model_b):
        config.parameters = {"t": 0.35}
        run_and_check_merge(config)

    def test_nuslerp_merges(self, model_a, model_b, model_c):
        for base_model in [None, model_c]:
            for row_wise in [False, True]:
                for flatten in [False, True]:
                    print(
                        f"Testing nuslerp with row_wise={row_wise}, flatten={flatten}, base_model={base_model}"
                    )
                    run_and_check_merge(
                        self.two_model_config(
                            model_a,
                            model_b,
                            merge_method="nuslerp",
                            base_model=base_model,
                            params={
                                "nuslerp_row_wise": row_wise,
                                "nuslerp_flatten": flatten,
                            },
                        )
                    )

        # test weights that sum to zero
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="nuslerp",
            base_model=model_c,
            params={"nuslerp_row_wise": False, "nuslerp_flatten": False},
        )
        config.models[0].parameters["weight"] = -0.5
        config.models[1].parameters["weight"] = 0.5
        run_and_check_merge(config)

    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a, model_b, merge_method="task_arithmetic", base_model=model_c