From d707173296b044de9a6150ac3d9759978ef06fbd Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Tue, 21 Nov 2023 13:02:53 -0800 Subject: [PATCH] Implement DARE merge --- mergekit/architecture.py | 2 +- mergekit/merge_methods/__init__.py | 38 +++- mergekit/merge_methods/base.py | 9 +- .../generalized_task_arithmetic.py | 166 ++++++++++++++++++ mergekit/merge_methods/taskarithmetic.py | 99 ----------- mergekit/merge_methods/ties.py | 142 --------------- mergekit/sparsify.py | 53 ++++++ 7 files changed, 257 insertions(+), 252 deletions(-) create mode 100644 mergekit/merge_methods/generalized_task_arithmetic.py delete mode 100644 mergekit/merge_methods/taskarithmetic.py delete mode 100644 mergekit/merge_methods/ties.py create mode 100644 mergekit/sparsify.py diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 9c376e8f..08b66466 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -181,8 +181,8 @@ def num_layers(self, config: PretrainedConfig) -> int: "transformer.ln_f.bias", "score.weight", ], - embed_weight_names=GPT2_INFO.embed_weight_names, layer_prefix_format="transformer.h.{idx}", + embed_weight_names=GPT2_INFO.embed_weight_names, layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes, num_layers_key=GPT2_INFO.num_layers_key, ) diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index a958d72c..8c47fcc2 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -14,25 +14,48 @@ # along with this program. If not, see http://www.gnu.org/licenses/. from mergekit.merge_methods.base import MergeMethod +from mergekit.merge_methods.generalized_task_arithmetic import ( + ConsensusMethod, + GeneralizedTaskArithmeticMerge, + SparsificationMethod, +) from mergekit.merge_methods.linear import LinearMerge from mergekit.merge_methods.passthrough import PassthroughMerge from mergekit.merge_methods.slerp import SlerpMerge -from mergekit.merge_methods.taskarithmetic import TaskArithmeticMerge -from mergekit.merge_methods.ties import TiesMerge from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge def get(method: str) -> MergeMethod: - if method == "ties": - return TiesMerge() - elif method == "linear": + if method == "linear": return LinearMerge() elif method == "slerp": return SlerpMerge() elif method == "passthrough": return PassthroughMerge() elif method == "task_arithmetic": - return TaskArithmeticMerge() + return GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=None, + default_normalize=False, + ) + elif method == "ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.magnitude, + default_normalize=True, + ) + elif method == "dare_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.rescaled_random, + default_normalize=False, + ) + elif method == "dare_linear": + return GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.rescaled_random, + default_normalize=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") @@ -40,9 +63,8 @@ def get(method: str) -> MergeMethod: "MergeMethod", "get", "LinearMerge", - "TiesMerge", "SlerpMerge", "PassthroughMerge", - "TaskArithmeticMerge", + "GeneralizedTaskArithmeticMerge", "TokenizerPermutationMerge", ] diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index 14344883..aabbacae 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -48,5 +48,10 @@ def input_layer_dependencies( def model_out_config(self, config: MergeConfiguration) -> PretrainedConfig: """Return a configuration for the resulting model.""" if config.base_model: - return ModelReference.parse(config.base_model).config() - return config.referenced_models()[0].config() + res = ModelReference.parse(config.base_model).config() + else: + res = config.referenced_models()[0].config() + + if config.dtype: + res.torch_dtype = config.dtype + return res diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py new file mode 100644 index 00000000..77f66c56 --- /dev/null +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -0,0 +1,166 @@ +# Copyright (C) 2023 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +import logging +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +from pydantic import BaseModel +from typing_extensions import Literal + +from mergekit.config import ConfigReader +from mergekit.graph import TensorReference +from mergekit.merge_methods.base import MergeMethod +from mergekit.sparsify import SparsificationMethod, sparsify + + +class ConsensusMethod(str, Enum): + count = "count" + sum = "sum" + + +class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel): + consensus_method: Optional[ConsensusMethod] + sparsification_method: Optional[SparsificationMethod] + default_normalize: bool + + def __call__( + self, + parameter_name: str, + input_tensors: Dict[TensorReference, torch.Tensor], + config: ConfigReader, + **kwargs, + ) -> torch.Tensor: + # collect task vectors + tvs, base = get_task_vectors( + parameter_name, + config, + input_tensors, + required_parameters=["weight"], + optional_parameters=["density"], + ) + if not tvs: + return base + + # sparsify + if self.sparsification_method: + for tv_info in tvs: + tv_info["delta"] = sparsify( + tv_info["delta"], + density=tv_info.get("density", 1.0), + method=self.sparsification_method, + ) + + deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) + weights = torch.tensor( + [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device + ) + while len(deltas.shape) > len(weights.shape): + weights.unsqueeze_(-1) + + weighted_deltas = deltas * weights + + # get sign consensus and mix deltas + if self.consensus_method: + mask_dtype = ( + torch.int8 + if config.parameter("int8_mask", default=False) + else base.dtype + ) + mask = get_mask( + weighted_deltas, method=self.consensus_method, mask_dtype=mask_dtype + ) + mixed_delta = (weighted_deltas * mask).sum(dim=0) + divisor = (weights * mask).sum(dim=0) + divisor[divisor == 0] = 1 + else: + mixed_delta = weighted_deltas.sum(dim=0) + divisor = weights.sum(dim=0) + divisor[divisor.abs() < 1e-8] = 1 + + if config.parameter("normalize", default=self.default_normalize): + mixed_delta /= divisor + + return (base + mixed_delta).to(base.dtype) + + +def get_task_vectors( + parameter_name: str, + config: ConfigReader, + input_tensors: Dict[TensorReference, torch.Tensor], + required_parameters: Optional[List[str]] = None, + optional_parameters: Optional[List[str]] = None, +) -> Tuple[List[torch.Tensor], List[float], torch.Tensor]: + tensors = {tr.model: value for (tr, value) in input_tensors.items()} + keys = list(tensors.keys()) + base = tensors[config.base_model] + + res = [] + for model in keys: + if model == config.base_model: + continue + + x = tensors[model].to(base.dtype) + if x.shape != base.shape: + if "lm_head" in parameter_name or "embed_tokens" in parameter_name: + x = x[: base.shape[0], : base.shape[1]] + logging.warning(f"Using submatrix of {model}:{parameter_name}") + else: + logging.warning( + f"skipping {model}:{parameter_name} due to size mismatch" + ) + continue + + delta = x - base + del x + del tensors[model] + + d = {} + d["model"] = model + d["delta"] = delta + for p in required_parameters: + d[p] = config.parameter(p, model, required=True) + for p in optional_parameters: + d[p] = config.parameter(p, model, required=False) + res.append(d) + return res, base + + +def get_mask( + delta: torch.Tensor, + method: Literal["sum", "count"] = "sum", + mask_dtype: Optional[torch.dtype] = None, +): + """Returns a mask determining which delta vectors should be merged + into the final model. + + For the methodology described in the paper use 'sum'. For a + simpler naive count of signs, use 'count'.""" + if mask_dtype is None: + mask_dtype = delta.dtype + + sign = delta.sign().to(mask_dtype) + + if method == "sum": + sign_weight = (sign * delta.abs()).sum(dim=0) + majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1 + del sign_weight + elif method == "count": + majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1 + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + + return sign == majority_sign diff --git a/mergekit/merge_methods/taskarithmetic.py b/mergekit/merge_methods/taskarithmetic.py deleted file mode 100644 index d2b037b4..00000000 --- a/mergekit/merge_methods/taskarithmetic.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (C) 2023 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -import logging -from typing import Dict, List, Optional, Tuple - -import torch - -from mergekit.config import ConfigReader -from mergekit.graph import TensorReference -from mergekit.merge_methods.base import MergeMethod - - -class TaskArithmeticMerge(MergeMethod): - def __call__( - self, - parameter_name: str, - input_tensors: Dict[TensorReference, torch.Tensor], - config: ConfigReader, - **kwargs, - ) -> torch.Tensor: - deltas, weights, base = get_task_vectors(parameter_name, config, input_tensors) - - if deltas: - deltas = torch.stack(deltas, dim=0) - weights = torch.tensor(weights, dtype=deltas.dtype, device=deltas.device) - while len(deltas.shape) > len(weights.shape): - weights.unsqueeze_(-1) - - mixed_delta = (weights * deltas).sum(dim=0) - - if config.parameter("normalize", default=False): - divisor = weights.sum(dim=0) - divisor[divisor.abs() < 1e-8] = 1 - mixed_delta /= divisor - - res = base + mixed_delta - else: - res = base - - return res.to(base.dtype) - - -def get_task_vectors( - parameter_name: str, - config: ConfigReader, - input_tensors: Dict[TensorReference, torch.Tensor], - skip_same: bool = False, - default_weight: Optional[float] = None, -) -> Tuple[List[torch.Tensor], List[float], torch.Tensor]: - tensors = {tr.model: value for (tr, value) in input_tensors.items()} - keys = list(tensors.keys()) - base = tensors[config.base_model] - - deltas = [] - weights = [] - for model in keys: - if model == config.base_model: - continue - - x = tensors[model].to(base.dtype) - if x.shape != base.shape: - if "lm_head" in parameter_name or "embed_tokens" in parameter_name: - x = x[: base.shape[0], : base.shape[1]] - logging.warning(f"Using submatrix of {model}:{parameter_name}") - else: - logging.warning( - f"skipping {model}:{parameter_name} due to size mismatch" - ) - continue - - if skip_same and (x == base).view(-1).all(): - continue - - deltas.append(x - base) - weights.append( - config.parameter( - "weight", - model, - default=default_weight, - required=default_weight is None, - ) - ) - - del tensors[model] - del x - return deltas, weights, base diff --git a/mergekit/merge_methods/ties.py b/mergekit/merge_methods/ties.py deleted file mode 100644 index bb1e3f22..00000000 --- a/mergekit/merge_methods/ties.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (C) 2023 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -import logging -from typing import Dict, Optional - -import torch -from typing_extensions import Literal - -from mergekit.config import ConfigReader -from mergekit.graph import TensorReference -from mergekit.merge_methods.base import MergeMethod - - -class TiesMerge(MergeMethod): - def __call__( - self, - parameter_name: str, - input_tensors: Dict[TensorReference, torch.Tensor], - config: ConfigReader, - **kwargs, - ) -> torch.Tensor: - tensors = {tr.model: value for (tr, value) in input_tensors.items()} - base = tensors[config.base_model] - - # resolve dtype for mask - mask_dtype = ( - torch.int8 if config.parameter("int8_mask", default=False) else base.dtype - ) - - deltas = [] - weights = [] - keys = list(tensors.keys()) - for model in keys: - if model == config.base_model: - continue - - x = tensors[model].to(base.dtype) - if x.shape != base.shape: - if "lm_head" in parameter_name or "embed_tokens" in parameter_name: - x = x[: base.shape[0], : base.shape[1]] - logging.warning(f"Using submatrix of {model}:{parameter_name}") - else: - logging.warning( - f"skipping {model}:{parameter_name} due to size mismatch" - ) - continue - - if (x == base).view(-1).all(): - continue - - deltas.append( - sparsify(x - base, config.parameter("density", model, default=0.33)) - ) - weights.append(config.parameter("weight", model, default=1.0)) - - del tensors[model] - del x - - if deltas: - deltas = torch.stack(deltas, dim=0) - weights = torch.tensor(weights, dtype=deltas.dtype, device=deltas.device) - while len(deltas.shape) > len(weights.shape): - weights.unsqueeze_(-1) - - weighted_deltas = weights * deltas - - mask = get_mask( - weighted_deltas, - method=config.parameter("consensus_method", default="sum"), - mask_dtype=mask_dtype, - ) - - mixed_delta = (weighted_deltas * mask).sum(dim=0) - - if config.parameter("normalize", default=True): - divisor = (weights * mask).sum(dim=0) - divisor[divisor == 0] = 1 - mixed_delta /= divisor - - res = base + mixed_delta - else: - res = base - - return res.to(base.dtype) - - -def sparsify(tensor: torch.Tensor, density: float) -> torch.Tensor: - """Masks out the smallest values, retaining a proportion of `density`.""" - if density >= 1: - return tensor - - k = int(density * tensor.view(-1).shape[0]) - - assert k > 0, "not gonna zero out the whole tensor buddy" - mask = torch.zeros_like(tensor) - w = tensor.abs().view(-1) - if w.device.type == "cpu": - w = w.float() - topk = torch.topk(w, k=k, largest=True) - mask.view(-1)[topk.indices] = 1 - - return tensor * mask - - -def get_mask( - delta: torch.Tensor, - method: Literal["sum", "count"] = "sum", - mask_dtype: Optional[torch.dtype] = None, -): - """Returns a mask determining which delta vectors should be merged - into the final model. - - For the methodology described in the paper use 'sum'. For a - simpler naive count of signs, use 'count'.""" - if mask_dtype is None: - mask_dtype = delta.dtype - - sign = delta.sign().to(mask_dtype) - - if method == "sum": - sign_weight = (sign * delta.abs()).sum(dim=0) - majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1 - del sign_weight - elif method == "count": - majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1 - else: - raise RuntimeError(f'Unimplemented mask method "{method}"') - - return sign == majority_sign diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py new file mode 100644 index 00000000..fa5cc192 --- /dev/null +++ b/mergekit/sparsify.py @@ -0,0 +1,53 @@ +from enum import Enum + +import torch + + +class SparsificationMethod(str, Enum): + magnitude = "magnitude" + random = "random" + rescaled_random = "rescaled_random" + + +def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor: + """Masks out the smallest values, retaining a proportion of `density`.""" + if density >= 1: + return tensor + + k = int(density * tensor.view(-1).shape[0]) + + assert k > 0, "not gonna zero out the whole tensor buddy" + mask = torch.zeros_like(tensor) + w = tensor.abs().view(-1) + if w.device.type == "cpu": + w = w.float() + topk = torch.topk(w, k=k, largest=True) + mask.view(-1)[topk.indices] = 1 + + return tensor * mask + + +def bernoulli( + tensor: torch.Tensor, density: float, rescale: bool = True +) -> torch.Tensor: + if density >= 1: + return tensor + + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + res = tensor * mask + if rescale: + res /= density + return res + + +def sparsify( + tensor: torch.Tensor, density: float, method: SparsificationMethod +) -> torch.Tensor: + if method == SparsificationMethod.magnitude: + return magnitude(tensor, density=density) + elif method == SparsificationMethod.random: + return bernoulli(tensor, density=density, rescale=False) + elif method == SparsificationMethod.rescaled_random: + return bernoulli(tensor, density=density, rescale=True) + else: + raise NotImplementedError(method)