
Implement DARE merge
cg123 committed Dec 10, 2023
1 parent b06bc1c commit d707173
Showing 7 changed files with 257 additions and 252 deletions.
2 changes: 1 addition & 1 deletion mergekit/architecture.py
@@ -181,8 +181,8 @@ def num_layers(self, config: PretrainedConfig) -> int:
         "transformer.ln_f.bias",
         "score.weight",
     ],
     embed_weight_names=GPT2_INFO.embed_weight_names,
-    layer_prefix_format="transformer.h.{idx}",
+    layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
     num_layers_key=GPT2_INFO.num_layers_key,
 )
38 changes: 30 additions & 8 deletions mergekit/merge_methods/__init__.py
@@ -14,35 +14,57 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.

 from mergekit.merge_methods.base import MergeMethod
+from mergekit.merge_methods.generalized_task_arithmetic import (
+    ConsensusMethod,
+    GeneralizedTaskArithmeticMerge,
+    SparsificationMethod,
+)
 from mergekit.merge_methods.linear import LinearMerge
 from mergekit.merge_methods.passthrough import PassthroughMerge
 from mergekit.merge_methods.slerp import SlerpMerge
-from mergekit.merge_methods.taskarithmetic import TaskArithmeticMerge
-from mergekit.merge_methods.ties import TiesMerge
 from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge


 def get(method: str) -> MergeMethod:
-    if method == "ties":
-        return TiesMerge()
-    elif method == "linear":
+    if method == "linear":
         return LinearMerge()
     elif method == "slerp":
         return SlerpMerge()
     elif method == "passthrough":
         return PassthroughMerge()
     elif method == "task_arithmetic":
-        return TaskArithmeticMerge()
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=None,
+            sparsification_method=None,
+            default_normalize=False,
+        )
+    elif method == "ties":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=ConsensusMethod.sum,
+            sparsification_method=SparsificationMethod.magnitude,
+            default_normalize=True,
+        )
+    elif method == "dare_ties":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=ConsensusMethod.sum,
+            sparsification_method=SparsificationMethod.rescaled_random,
+            default_normalize=False,
+        )
+    elif method == "dare_linear":
+        return GeneralizedTaskArithmeticMerge(
+            consensus_method=None,
+            sparsification_method=SparsificationMethod.rescaled_random,
+            default_normalize=False,
+        )
     raise RuntimeError(f"Unimplemented merge method {method}")


 __all__ = [
     "MergeMethod",
     "get",
     "LinearMerge",
-    "TiesMerge",
     "SlerpMerge",
     "PassthroughMerge",
-    "TaskArithmeticMerge",
+    "GeneralizedTaskArithmeticMerge",
     "TokenizerPermutationMerge",
 ]
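
Note on the two dare_* presets above: they reuse the same generalized task-arithmetic pipeline but select SparsificationMethod.rescaled_random, i.e. the DARE drop-and-rescale step, instead of TIES-style magnitude pruning. The actual implementation lives in mergekit/sparsify.py (one of the changed files not expanded in this view); the snippet below is only a minimal illustrative sketch of the drop-and-rescale idea, with the function name and the density handling chosen here for illustration.

import torch


def dare_drop_and_rescale(delta: torch.Tensor, density: float) -> torch.Tensor:
    """Illustrative sketch only (not mergekit's sparsify()): keep each entry of
    the task vector with probability `density`, then rescale the survivors by
    1/density so the expected value of the delta is preserved."""
    if density >= 1.0:
        return delta
    keep = torch.bernoulli(torch.full_like(delta, density))
    return delta * keep / density
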
9 changes: 7 additions & 2 deletions mergekit/merge_methods/base.py
@@ -48,5 +48,10 @@ def input_layer_dependencies(
     def model_out_config(self, config: MergeConfiguration) -> PretrainedConfig:
         """Return a configuration for the resulting model."""
         if config.base_model:
-            return ModelReference.parse(config.base_model).config()
-        return config.referenced_models()[0].config()
+            res = ModelReference.parse(config.base_model).config()
+        else:
+            res = config.referenced_models()[0].config()
+
+        if config.dtype:
+            res.torch_dtype = config.dtype
+        return res
166 changes: 166 additions & 0 deletions mergekit/merge_methods/generalized_task_arithmetic.py
@@ -0,0 +1,166 @@
# Copyright (C) 2023 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from enum import Enum
from typing import Dict, List, Optional, Tuple

import torch
from pydantic import BaseModel
from typing_extensions import Literal

from mergekit.config import ConfigReader
from mergekit.graph import TensorReference
from mergekit.merge_methods.base import MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify


class ConsensusMethod(str, Enum):
    count = "count"
    sum = "sum"


class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel):
    consensus_method: Optional[ConsensusMethod]
    sparsification_method: Optional[SparsificationMethod]
    default_normalize: bool

    def __call__(
        self,
        parameter_name: str,
        input_tensors: Dict[TensorReference, torch.Tensor],
        config: ConfigReader,
        **kwargs,
    ) -> torch.Tensor:
        # collect task vectors
        tvs, base = get_task_vectors(
            parameter_name,
            config,
            input_tensors,
            required_parameters=["weight"],
            optional_parameters=["density"],
        )
        if not tvs:
            return base

        # sparsify
        if self.sparsification_method:
            for tv_info in tvs:
                tv_info["delta"] = sparsify(
                    tv_info["delta"],
                    density=tv_info.get("density", 1.0),
                    method=self.sparsification_method,
                )

        deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
        weights = torch.tensor(
            [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
        )
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)

        weighted_deltas = deltas * weights

        # get sign consensus and mix deltas
        if self.consensus_method:
            mask_dtype = (
                torch.int8
                if config.parameter("int8_mask", default=False)
                else base.dtype
            )
            mask = get_mask(
                weighted_deltas, method=self.consensus_method, mask_dtype=mask_dtype
            )
            mixed_delta = (weighted_deltas * mask).sum(dim=0)
            divisor = (weights * mask).sum(dim=0)
            divisor[divisor == 0] = 1
        else:
            mixed_delta = weighted_deltas.sum(dim=0)
            divisor = weights.sum(dim=0)
            divisor[divisor.abs() < 1e-8] = 1

        if config.parameter("normalize", default=self.default_normalize):
            mixed_delta /= divisor

        return (base + mixed_delta).to(base.dtype)


def get_task_vectors(
    parameter_name: str,
    config: ConfigReader,
    input_tensors: Dict[TensorReference, torch.Tensor],
    required_parameters: Optional[List[str]] = None,
    optional_parameters: Optional[List[str]] = None,
) -> Tuple[List[torch.Tensor], List[float], torch.Tensor]:
    tensors = {tr.model: value for (tr, value) in input_tensors.items()}
    keys = list(tensors.keys())
    base = tensors[config.base_model]

    res = []
    for model in keys:
        if model == config.base_model:
            continue

        x = tensors[model].to(base.dtype)
        if x.shape != base.shape:
            if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
                x = x[: base.shape[0], : base.shape[1]]
                logging.warning(f"Using submatrix of {model}:{parameter_name}")
            else:
                logging.warning(
                    f"skipping {model}:{parameter_name} due to size mismatch"
                )
                continue

        delta = x - base
        del x
        del tensors[model]

        d = {}
        d["model"] = model
        d["delta"] = delta
        for p in required_parameters:
            d[p] = config.parameter(p, model, required=True)
        for p in optional_parameters:
            d[p] = config.parameter(p, model, required=False)
        res.append(d)
    return res, base


def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
):
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the paper use 'sum'. For a
    simpler naive count of signs, use 'count'."""
    if mask_dtype is None:
        mask_dtype = delta.dtype

    sign = delta.sign().to(mask_dtype)

    if method == "sum":
        sign_weight = (sign * delta.abs()).sum(dim=0)
        majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
        del sign_weight
    elif method == "count":
        majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    return sign == majority_sign
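
To see what the 'sum' consensus in get_mask does, here is a small self-contained toy example (illustrative only, with the election logic re-implemented inline rather than imported from mergekit). Two weighted task-vector deltas disagree on the sign of some coordinates; the weighted sign sum elects a majority sign per coordinate, and entries that disagree with it are masked out before the deltas are mixed.

import torch

# Two weighted task-vector deltas for the same parameter (toy values).
weighted_deltas = torch.tensor([[1.0, -2.0, 0.5],
                                [3.0, 1.0, -0.5]])

sign = weighted_deltas.sign()
sign_weight = (sign * weighted_deltas.abs()).sum(dim=0)  # weighted sign sum per coordinate
majority_sign = (sign_weight >= 0).to(weighted_deltas.dtype) * 2 - 1

mask = sign == majority_sign  # keep only entries that agree with the elected sign
mixed_delta = (weighted_deltas * mask).sum(dim=0)

print(majority_sign)  # tensor([ 1., -1.,  1.])
print(mixed_delta)    # tensor([ 4.0000, -2.0000,  0.5000])
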
99 changes: 0 additions & 99 deletions mergekit/merge_methods/taskarithmetic.py

This file was deleted.

