Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
add consensus_ties and consensus_ta method from
https://arxiv.org/abs/2405.07813
  • Loading branch information
zsgvivo authored Nov 30, 2024
1 parent 68c4b65 commit 8d1a10d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
16 changes: 16 additions & 0 deletions mergekit/merge_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ def get(method: str) -> MergeMethod:
default_normalize=False,
default_rescale=True,
)

elif method == "consensus_ta":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=SparsificationMethod.consensus_ta,
default_normalize=False,
default_rescale=False,
)

elif method == "consensus_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.consensus_ties,
default_normalize=True,
default_rescale=False,
)
raise RuntimeError(f"Unimplemented merge method {method}")


Expand Down
43 changes: 39 additions & 4 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
MergeMethod,
MergeTensorInput,
)
from mergekit.sparsify import SparsificationMethod, sparsify
from mergekit.sparsify import SparsificationMethod, get_tall_mask, sparsify


class ConsensusMethod(str, Enum):
Expand Down Expand Up @@ -79,6 +79,22 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
default_value=1.0,
)
)
if (
self.sparsification_method == SparsificationMethod.consensus_ta
or self.sparsification_method == SparsificationMethod.consensus_ties
):
res.append(
ConfigParameterDef(
name="k",
default_value=1,
)
)
res.append(
ConfigParameterDef(
name="lambda",
default_value=1.0,
)
)
return res

def make_task(
Expand Down Expand Up @@ -133,7 +149,10 @@ def execute(
return base

# sparsify
if self.method.sparsification_method:
if (
self.method.sparsification_method
and self.method.sparsification_method != SparsificationMethod.consensus_ta
):
for tv_info in tvs:
kwargs = {}
if "gamma" in tv_info:
Expand All @@ -142,15 +161,17 @@ def execute(
if "epsilon" in tv_info:
kwargs["epsilon"] = tv_info["epsilon"]

tv_info["delta"] = sparsify(
tv_info["sparsified_delta"] = sparsify(
tv_info["delta"],
density=tv_info["density"],
method=self.method.sparsification_method,
rescale=self.rescale,
**kwargs,
)

deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
deltas = torch.stack([tv["sparsified_delta"] for tv in tvs], dim=0)
else:
deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
weights = torch.tensor(
[tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
)
Expand Down Expand Up @@ -185,6 +206,20 @@ def execute(
lambda_factor = tvs[0]["lambda"]
mixed_delta *= lambda_factor

if (
self.method.sparsification_method == SparsificationMethod.consensus_ta
or self.method.sparsification_method == SparsificationMethod.consensus_ties
):
for tv_info in tvs:
tv_info["tall_mask"] = get_tall_mask(
tv_info["delta"],
tv_info["lambda"],
mixed_delta,
)
tall_masks = torch.stack([tv["tall_mask"] for tv in tvs], dim=0)
consensus_mask = tall_masks.sum(dim=0) >= tvs[0]["k"]
mixed_delta = mixed_delta * consensus_mask

return (base + mixed_delta).to(base.dtype)

def group_label(self) -> Optional[str]:
Expand Down
16 changes: 15 additions & 1 deletion mergekit/sparsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class SparsificationMethod(str, Enum):
random = "random"
magnitude_outliers = "magnitude_outliers"
rank_magnitude_sampling = "rank_magnitude_sampling"
consensus_ta = "consensus_ta"
consensus_ties = "consensus_ties"


def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor):
Expand Down Expand Up @@ -177,7 +179,10 @@ def sparsify(
rescale: bool = False,
epsilon: float = 0.15,
) -> torch.Tensor:
if method == SparsificationMethod.magnitude:
if (
method == SparsificationMethod.magnitude
or method == SparsificationMethod.consensus_ties
):
return magnitude(tensor, density=density, rescale=rescale)
elif method == SparsificationMethod.random:
return bernoulli(tensor, density=density, rescale=rescale)
Expand All @@ -187,3 +192,12 @@ def sparsify(
return rank_magnitude(tensor, density=density, rescale=rescale, epsilon=epsilon)
else:
raise NotImplementedError(method)


def get_tall_mask(
delta: torch.Tensor, # individual task vectors
lambda_factor: float, # hyper-parameter lambda for generating TALL masks
mixed_delta: torch.Tensor, # multi-task vector
):
mask = delta.abs() > lambda_factor * (mixed_delta - delta).abs()
return mask

0 comments on commit 8d1a10d

Please sign in to comment.