diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 5f60d830..68b7648d 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -39,6 +39,7 @@ def get(method: str) -> MergeMethod: sparsification_method=None, default_normalize=False, default_rescale=False, + default_smooth=False, ) elif method == "ties": return GeneralizedTaskArithmeticMerge( @@ -46,6 +47,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.magnitude, default_normalize=True, default_rescale=False, + default_smooth=False, ) elif method == "dare_ties": return GeneralizedTaskArithmeticMerge( @@ -53,6 +55,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.random, default_normalize=False, default_rescale=True, + default_smooth=False, ) elif method == "dare_linear": return GeneralizedTaskArithmeticMerge( @@ -60,9 +63,26 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.random, default_normalize=False, default_rescale=True, + default_smooth=False, ) elif method == "model_stock": return ModelStockMerge() + elif method == "sample_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.sample, + default_normalize=False, + default_rescale=True, + default_smooth=False, + ) + elif method == "ranked_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.ranked, + default_normalize=False, + default_rescale=True, + default_smooth=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 2bfbcd74..f614b869 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -39,6 +39,7 @@ class 
class SparsificationMethod(str, Enum):
    """Supported tensor-sparsification strategies."""

    magnitude = "magnitude"
    random = "random"
    sample = "sample"
    ranked = "ranked"


def ranked(
    tensor: torch.Tensor, density: float, rescale: bool, smooth: bool
) -> torch.Tensor:
    """Sparsify by magnitude rank with a graded (soft) mask.

    Every element gets a mask weight from ``linspace(1, 0)`` over its
    magnitude ranking, raised to ``(1/density - 1)`` so the mask's mean
    equals ``density``. With ``smooth=True`` the continuous mask is applied
    directly; otherwise it is discretized by a Bernoulli draw into a hard
    0/1 mask. ``rescale=True`` rescales the result via ``rescale_sum`` to
    preserve the tensor's sum (matching ``sample`` / the other methods).
    """
    if density >= 1:
        return tensor

    # Already sparser than the target density (fraction of nonzero weights)
    # -> nothing to do, in line with magnitude trimming.
    # (The previous check used `tensor.abs() ** 0.0`, which is all-ones and
    # therefore never triggered.)
    if (tensor != 0).float().mean() <= density:
        return tensor

    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
        work_dtype = tensor.dtype
    else:
        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
        work_dtype = torch.float32

    size = int(tensor.view(-1).shape[0])

    mask = torch.zeros(tensor.shape, dtype=work_dtype, device=tensor.device)
    w = tensor.abs().view(-1)
    if w.device.type == "cpu":
        # argsort support for half precision on CPU is limited; sort in fp32.
        w = w.float()
    sort = torch.argsort(w, descending=True)

    # Rank-based weights in [0, 1]; exponent (1/density - 1) makes the mask's
    # mean equal the requested density.
    mask.view(-1)[sort] = torch.linspace(
        1, 0, steps=size, device=w.device, dtype=work_dtype
    ).pow((1 / density) - 1)

    if not smooth:
        # Discretize the soft mask into a hard 0/1 mask.
        mask = torch.bernoulli(mask)

    if rescale:
        res = rescale_sum(tensor, mask)
    else:
        res = tensor * mask

    return res.to(tensor.dtype)


def sample(
    tensor: torch.Tensor, density: float, rescale: bool, smooth: bool
) -> torch.Tensor:
    """Samples the tensor as its own mask, then shifts mean to fit density.

    The normalized magnitudes ``|t|^p / max(|t|^p)`` are used as a soft mask,
    with the power ``p`` fitted iteratively so the mask's mean matches
    ``density``. With ``smooth=True`` the continuous mask is applied
    directly; otherwise it is discretized by a Bernoulli draw.
    ``rescale=True`` rescales the result via ``rescale_sum``.
    """
    if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"):
        return tensor

    # Already sparser than the target density (fraction of nonzero weights)
    # -> nothing to do, in line with magnitude trimming.
    if (tensor != 0).float().mean() <= density:
        return tensor

    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
        work_dtype = tensor.dtype
    else:
        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
        work_dtype = torch.float32

    # Fixed-point search for the power that makes mean(|t|^p)/max(|t|^p)
    # match the requested density; power += (avg - density) converges from
    # either side, so iterate until |avg - density| is small (or 15 steps).
    i = 0
    power = 1.0
    avg = tensor.abs().mean() / tensor.abs().max()
    while abs(avg - density) > 1e-5 and i < 15:
        intermediate = tensor.abs() ** power
        avg = intermediate.mean() / intermediate.max()
        power += avg - density
        if power < 0:
            power = 0
        i += 1

    intermediate = tensor.abs() ** power
    mask = (intermediate / intermediate.max()).to(work_dtype)
    if not smooth:
        mask = torch.bernoulli(mask)

    if rescale:
        res = rescale_sum(tensor, mask)
    else:
        res = tensor * mask
    return res.to(tensor.dtype)


def sparsify(
    tensor: torch.Tensor,
    density: float,
    method: SparsificationMethod,
    rescale: bool = False,
    smooth: bool = False,
) -> torch.Tensor:
    """Dispatch to the requested sparsification method.

    ``rescale`` and ``smooth`` keep defaults (False) so existing callers that
    pass only (tensor, density, method) remain valid.
    """
    if method == SparsificationMethod.magnitude:
        return magnitude(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.random:
        return bernoulli(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.sample:
        return sample(tensor, density=density, rescale=rescale, smooth=smooth)
    elif method == SparsificationMethod.ranked:
        return ranked(tensor, density=density, rescale=rescale, smooth=smooth)
    else:
        raise NotImplementedError(method)