-
Notifications
You must be signed in to change notification settings - Fork 458
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a new merge method `nuslerp`. This method allows for a superset of the functionality of `slerp`. If provided with a base model, `nuslerp` will perform spherical interpolation of the task vectors. While the original `slerp` always flattens weight tensors into a single dimension, `nuslerp` can also do row-wise and column-wise interpolation of tensors. This method remedies one of my long-standing gripes with how I implemented `slerp`. Instead of taking a `t` parameter and using `base_model` to specify which is the "first" model, `nuslerp` simply takes a `weight` parameter for each model and computes the interpolation factor `t` internally. This makes it fit the conventions of the other merge methods much better. The `weight` parameter behaves in the same fashion as it does for `merge_method: linear` with `normalize: true`. The idea to add task vector SLERP is inspired by DeepMind's great use of it in their [WARP](https://arxiv.org/abs/2406.16768) paper.
- Loading branch information
Showing
5 changed files
with
221 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# Copyright (C) 2024 Charles O. Goddard | ||
# | ||
# This software is free software: you can redistribute it and/or | ||
# modify it under the terms of the GNU Lesser General Public License as | ||
# published by the Free Software Foundation, either version 3 of the | ||
# License, or (at your option) any later version. | ||
# | ||
# This software is distributed in the hope that it will be useful, but | ||
# WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | ||
# Lesser General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Lesser General Public License | ||
# along with this program. If not, see http://www.gnu.org/licenses/. | ||
|
||
from typing import Any, Dict, List, Optional | ||
|
||
import torch | ||
from torch._tensor import Tensor | ||
|
||
from mergekit.architecture import WeightInfo | ||
from mergekit.common import ImmutableMap, ModelReference | ||
from mergekit.graph import Task | ||
from mergekit.io.tasks import GatherTensors | ||
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod | ||
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes | ||
|
||
|
||
class NuSlerpTask(Task[torch.Tensor]):
    """Task that spherically interpolates one tensor between two models.

    When ``base_model`` is set, the interpolation is applied to the two
    task vectors (each model's tensor minus the base tensor) and the base
    tensor is added back to the result. Otherwise the two model tensors
    are interpolated directly.
    """

    # Upstream task producing the per-model tensors to merge.
    gather_tensors: GatherTensors
    # Per-model tensor parameters; only the "weight" entry is read here.
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    # Passed through to ``rectify_embed_sizes``.
    weight_info: WeightInfo
    # If true, interpolate along dim 0 instead of the last dim.
    row_wise: bool
    # If true, treat each tensor as one flat vector.
    flatten: bool
    # Optional base model for task-vector interpolation.
    base_model: Optional[ModelReference]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """Merge the gathered tensors.

        Args:
            tensors: Mapping from model reference to that model's tensor.

        Returns:
            The single input tensor unchanged if only one was gathered,
            otherwise the interpolated tensor.

        Raises:
            RuntimeError: If a base model is set and the number of tensors
                is not three, or if the number of tensors to merge (after
                removing the base) is not two.
        """
        if len(tensors) == 1:
            return list(tensors.values())[0]

        # Shallow copy so removing the base model below does not mutate
        # the dict object handed in by the caller.
        remaining = dict(tensors)

        if self.base_model is not None:
            if len(remaining) != 3:
                raise RuntimeError(
                    "NuSlerp base model can not be one of the two models to merge"
                )
            base_tensor = remaining.pop(self.base_model)
        else:
            base_tensor = None

        keys = list(remaining.keys())
        operands = [remaining[key] for key in keys]
        weights = [self.tensor_parameters[key]["weight"] for key in keys]

        if len(operands) != 2:
            # Diagnostic details go in the exception rather than to stdout.
            raise RuntimeError(
                "NuSlerp merge expects exactly two models (plus optional base model)"
                f"; got models {keys} with base model {self.base_model}"
            )

        if abs(sum(weights)) < 1e-6:
            # this is fairly arbitrary, but it's more sane than exploding
            t = 0.5
        else:
            # Normalized weight of the second model is the interpolation factor.
            t = weights[1] / sum(weights)

        # The base tensor (if any) is appended so it goes through
        # rectify_embed_sizes together with the model tensors; it is read
        # back from the list afterwards.
        if base_tensor is not None:
            operands.append(base_tensor)
        rectify_embed_sizes(self.weight_info, operands)

        interp_dim = 0 if self.row_wise else -1

        if base_tensor is not None:
            base_tensor = operands.pop()
            return base_tensor + nuslerp(
                t,
                operands[0] - base_tensor,
                operands[1] - base_tensor,
                dim=interp_dim,
                flatten=self.flatten,
            )
        return nuslerp(
            t,
            operands[0],
            operands[1],
            dim=interp_dim,
            flatten=self.flatten,
        )
|
||
|
||
class NuSlerpMerge(MergeMethod):
    """Merge method that builds a ``NuSlerpTask`` for each output weight."""

    def parameters(self) -> List[ConfigParameterDef]:
        """Return the method-level options; both are optional."""
        row_wise_def = ConfigParameterDef(
            name="nuslerp_row_wise",
            required=False,
            default_value=False,
        )
        flatten_def = ConfigParameterDef(
            name="nuslerp_flatten",
            required=False,
            default_value=True,
        )
        return [row_wise_def, flatten_def]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """Return the per-model options: a required ``weight``."""
        weight_def = ConfigParameterDef(name="weight", required=True)
        return [weight_def]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        """Create the task that merges ``output_weight``.

        Extra keyword arguments are accepted and ignored.
        """
        use_row_wise = parameters["nuslerp_row_wise"]
        use_flatten = parameters["nuslerp_flatten"]
        return NuSlerpTask(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            weight_info=output_weight,
            row_wise=use_row_wise,
            flatten=use_flatten,
            base_model=base_model,
        )
|
||
|
||
def nuslerp(
    t: float,
    v0: torch.Tensor,
    v1: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-8,
    flatten: bool = False,
) -> torch.Tensor:
    """Spherically interpolate between ``v0`` and ``v1``.

    Args:
        t: Interpolation factor; 0 yields ``v0`` and 1 yields ``v1``.
        v0: First tensor.
        v1: Second tensor, used elementwise against ``v0``.
        dim: Dimension along which vectors are formed when ``flatten`` is
            false. It is swapped with the last dimension for the
            computation and swapped back afterwards.
        eps: Vectors whose ``|sin(theta)|`` is below this are treated as
            colinear and linearly interpolated instead.
        flatten: If true, treat each whole tensor as a single vector and
            ignore ``dim``.

    Returns:
        A tensor with the shape of ``v0``.
    """
    out_shape = v0.shape

    def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        # Clamp the norm so all-zero vectors do not divide by zero.
        return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps)

    if flatten:
        # reshape (not view) so non-contiguous inputs are accepted.
        v0 = v0.reshape(-1)
        v1 = v1.reshape(-1)
    elif dim != -1:
        v0 = v0.transpose(dim, -1)
        v1 = v1.transpose(dim, -1)

    v0_u = _normalize(v0)
    v1_u = _normalize(v1)

    cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True)
    theta = torch.acos(cos_theta.clamp(-1, 1))
    sin_theta = torch.sin(theta)

    # Drop only the kept reduction dim. A bare squeeze() would also drop
    # any other size-1 dims and make the mask mismatch `res`.
    colinear = (sin_theta.abs() < eps).squeeze(-1)

    res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta
    # Use linear interpolation for (nearly) colinear vectors; this also
    # overwrites the non-finite values produced by dividing by ~0 above.
    res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear]

    if dim != -1 and not flatten:
        res = res.transpose(dim, -1)
    # reshape (not view): the result's memory layout is not guaranteed to
    # be view-compatible with out_shape.
    return res.reshape(out_shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters