Add Della merge method (#366)
Adds a new merge method, `della`. DELLA first ranks the delta parameters in
each row and assigns drop probabilities adaptively, inversely proportional
to their magnitudes: delta parameters with higher magnitudes receive lower
drop probabilities. After the drop probabilities are assigned, the delta
parameters are dropped and rescaled in a manner similar to the DARE method.
The DELLA-merging paper can be found [here](https://arxiv.org/abs/2406.11617).
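A condensed sketch of the adaptive dropping step (illustration only; the actual implementation is the `rank_magnitude` function added to `mergekit/sparsify.py` in this commit):

```python
import torch

# Toy row of delta parameters (finetuned weights minus base weights).
delta = torch.tensor([[0.30, -0.05, 0.20, 0.01]])
density, epsilon = 0.5, 0.15

# Rank each row's entries by magnitude: 1 = smallest, n = largest.
ranks = torch.argsort(torch.argsort(delta.abs(), dim=1), dim=1) + 1.0

# Map ranks linearly onto retention probabilities in
# [density - epsilon, density + epsilon]: larger magnitudes are kept
# with higher probability, i.e. dropped with lower probability.
norm = (ranks - ranks.min(dim=1, keepdim=True).values) / (
    ranks.max(dim=1, keepdim=True).values - ranks.min(dim=1, keepdim=True).values
)
keep_prob = (density - epsilon) + norm * (2 * epsilon)

# Drop stochastically, then rescale the survivors as in DARE.
mask = torch.bernoulli(keep_prob)
pruned = delta * mask / keep_prob
```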
Tej-Deep authored Jul 20, 2024
1 parent 6447a85 commit 619f4e4
Showing 4 changed files with 104 additions and 1 deletion.
12 changes: 11 additions & 1 deletion README.md
@@ -128,7 +128,8 @@ A quick overview of the currently supported merge methods:
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) | `breadcrumbs` |||
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` |||
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` |||
| [DELLA](https://arxiv.org/abs/2406.11617) | `della` |||
| [DELLA](https://arxiv.org/abs/2406.11617) + [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` |||

### Linear

The classic merge method - a simple weighted average.
@@ -189,6 +190,15 @@ Parameters:

- `filter_wise`: if true, weight calculation will be per-row rather than per-tensor. Not recommended.

### [DELLA](https://arxiv.org/abs/2406.11617)

Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks the delta parameters in each row and assigns drop probabilities inversely proportional to their magnitudes, allowing it to retain the most important changes while reducing interference. After pruning, it rescales the remaining parameters as in [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign-election step of TIES.

Parameters: same as [Linear](#linear), plus:
- `density` - fraction of weights in differences from the base model to retain
- `epsilon` - maximum change in drop probability based on magnitude rank. Retention probabilities are assigned in the range `density - epsilon` to `density + epsilon`, so choose `density` and `epsilon` such that this whole range stays within (0, 1); see the sketch below.
- `lambda` - scaling factor applied to the merged delta parameters before they are added to the base parameters.
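For the sketch referenced above, a minimal sanity check of the `density`/`epsilon` interaction (plain Python, not part of mergekit's API):

```python
density, epsilon = 0.7, 0.15

# Retention probabilities are assigned per row, linearly in magnitude rank,
# spanning [density - epsilon, density + epsilon] = [0.55, 0.85] here.
low, high = density - epsilon, density + epsilon
assert 0.0 < low and high < 1.0, "density +/- epsilon must stay inside (0, 1)"
```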

## LoRA extraction

Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models.
16 changes: 16 additions & 0 deletions mergekit/merge_methods/__init__.py
@@ -77,6 +77,22 @@ def get(method: str) -> MergeMethod:
        )
    elif method == "model_stock":
        return ModelStockMerge()

    elif method == "della":
        return GeneralizedTaskArithmeticMerge(
            consensus_method=ConsensusMethod.sum,
            sparsification_method=SparsificationMethod.rank_magnitude_sampling,
            default_normalize=True,
            default_rescale=True,
        )

    elif method == "della_linear":
        return GeneralizedTaskArithmeticMerge(
            consensus_method=None,
            sparsification_method=SparsificationMethod.rank_magnitude_sampling,
            default_normalize=False,
            default_rescale=True,
        )
    raise RuntimeError(f"Unimplemented merge method {method}")


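A hedged usage sketch of the factory above: the two new names select the same rank-magnitude sparsification and differ only in their consensus and normalization defaults.

```python
from mergekit.merge_methods import get

# TIES-style sign election over the sparsified deltas, normalized.
della = get("della")

# Plain task arithmetic over the sparsified deltas, unnormalized.
della_linear = get("della_linear")
```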
23 changes: 23 additions & 0 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Expand Up @@ -66,6 +66,19 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
                    default_value=0.01,
                )
            )
        if self.sparsification_method == SparsificationMethod.rank_magnitude_sampling:
            res.append(
                ConfigParameterDef(
                    name="epsilon",
                    default_value=0.15,
                )
            )
            res.append(
                ConfigParameterDef(
                    name="lambda",
                    default_value=1.0,
                )
            )
        return res

    def make_task(
@@ -126,6 +139,9 @@ def execute(
if "gamma" in tv_info:
kwargs["gamma"] = tv_info["gamma"]

if "epsilon" in tv_info:
kwargs["epsilon"] = tv_info["epsilon"]

tv_info["delta"] = sparsify(
tv_info["delta"],
density=tv_info["density"],
@@ -162,6 +178,13 @@ def execute(
        if self.normalize:
            mixed_delta /= divisor

        if (
            self.method.sparsification_method
            == SparsificationMethod.rank_magnitude_sampling
        ):
            lambda_factor = tvs[0]["lambda"]
            mixed_delta *= lambda_factor

        return (base + mixed_delta).to(base.dtype)

    def group_label(self) -> Optional[str]:
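A small numeric sketch of the final recombination step above (toy values, not from the paper): for DELLA, the merged delta is scaled by `lambda` before being added back to the base weights.

```python
import torch

base = torch.tensor([1.0, 2.0, 3.0])
mixed_delta = torch.tensor([0.2, -0.1, 0.0])  # consensus-merged, sparsified deltas
lambda_factor = 0.9  # the `lambda` tensor parameter (default 1.0)

merged = base + lambda_factor * mixed_delta  # tensor([1.1800, 1.9100, 3.0000])
```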
54 changes: 54 additions & 0 deletions mergekit/sparsify.py
Expand Up @@ -22,6 +22,7 @@ class SparsificationMethod(str, Enum):
    magnitude = "magnitude"
    random = "random"
    magnitude_outliers = "magnitude_outliers"
    rank_magnitude_sampling = "rank_magnitude_sampling"


def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor):
@@ -115,21 +116,74 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    res = tensor.to(work_dtype) * mask
    if rescale:
        res /= density

    return res.to(tensor.dtype)


def rank_magnitude(
    tensor: torch.Tensor, density: float, rescale: bool = True, epsilon: float = 0.05
) -> torch.Tensor:
    if density >= 1:
        return tensor

    if density <= epsilon or density >= (1 - epsilon):
        raise ValueError(
            f"density +- epsilon must lie in the range (0, 1). density + epsilon = {density+epsilon}, density - epsilon = {density-epsilon}"
        )

    # torch.bernoulli is not implemented for float16 on CPU, so upcast there
    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
        work_dtype = tensor.dtype
    else:
        work_dtype = torch.float32

    if len(tensor.shape) < 2:
        tensor = tensor.unsqueeze(0)

    # Rank the delta values in each row by ascending magnitude (1 = smallest)
    tensor_abs = torch.abs(tensor)

    sorted_indices = torch.argsort(tensor_abs, dim=1, descending=False)

    ranking_tensor = torch.zeros_like(tensor_abs, dtype=work_dtype)
    for i in range(tensor_abs.size(0)):
        ranking_tensor[i][sorted_indices[i]] = torch.arange(
            1, tensor.size(1) + 1, dtype=work_dtype
        ).to(tensor.device)

    # Map ranks linearly onto retention probabilities in
    # [density - epsilon, density + epsilon]
    range_vals = (
        ranking_tensor.max(dim=1, keepdim=True).values
        - ranking_tensor.min(dim=1, keepdim=True).values
    )
    norm_metrics = (
        ranking_tensor - ranking_tensor.min(dim=1, keepdim=True).values
    ) / range_vals
    final_probabilities = (density - epsilon) + norm_metrics * (2 * epsilon)

    mask = torch.bernoulli(final_probabilities).to(work_dtype)
    res = tensor.to(work_dtype) * mask

    if rescale:
        res = res / final_probabilities.to(work_dtype)

    return res.squeeze(0)


def sparsify(
    tensor: torch.Tensor,
    density: float,
    method: SparsificationMethod,
    gamma: float = 0,
    rescale: bool = False,
    epsilon: float = 0.15,
) -> torch.Tensor:
    if method == SparsificationMethod.magnitude:
        return magnitude(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.random:
        return bernoulli(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.magnitude_outliers:
        return magnitude_outliers(tensor, density=density, rescale=rescale, gamma=gamma)
    elif method == SparsificationMethod.rank_magnitude_sampling:
        return rank_magnitude(tensor, density=density, rescale=rescale, epsilon=epsilon)
    else:
        raise NotImplementedError(method)

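An end-to-end usage sketch of the new sparsification path (toy values; `sparsify` and `SparsificationMethod` are the objects defined above):

```python
import torch

from mergekit.sparsify import SparsificationMethod, sparsify

# Prune a toy 2x4 delta tensor with DELLA's rank-magnitude sampling.
delta = torch.tensor(
    [[0.30, -0.05, 0.20, 0.01],
     [0.10, 0.40, -0.02, 0.08]]
)
pruned = sparsify(
    delta,
    density=0.5,
    method=SparsificationMethod.rank_magnitude_sampling,
    rescale=True,
    epsilon=0.15,
)
# With rescale=True, surviving entries are divided by their retention
# probability, preserving each delta's expected value, as in DARE.
```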