Skip to content

Commit

Permalink
Formatting files
Browse files Browse the repository at this point in the history
  • Loading branch information
Tej-Deep committed Jul 16, 2024
1 parent 05253c4 commit 0031a40
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
4 changes: 2 additions & 2 deletions mergekit/merge_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def get(method: str) -> MergeMethod:
)
elif method == "model_stock":
return ModelStockMerge()

elif method == "della":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.rank_magnitude_sampling,
default_normalize=True,
default_rescale=True,
)

elif method == "della_linear":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
Expand Down
9 changes: 6 additions & 3 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def execute(
kwargs = {}
if "gamma" in tv_info:
kwargs["gamma"] = tv_info["gamma"]

if "epsilon" in tv_info:
kwargs["epsilon"] = tv_info["epsilon"]

Expand Down Expand Up @@ -174,8 +174,11 @@ def execute(

if self.normalize:
mixed_delta /= divisor

if self.method.sparsification_method == SparsificationMethod.rank_magnitude_sampling:

if (
self.method.sparsification_method
== SparsificationMethod.rank_magnitude_sampling
):
lambda_factor = tvs[0]["lambda"]
mixed_delta *= lambda_factor

Expand Down
32 changes: 21 additions & 11 deletions mergekit/sparsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,37 +119,47 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens

return res.to(tensor.dtype)


def rank_magnitude(
tensor: torch.Tensor, density: float, rescale: bool = True, epsilon: float = 0.05
) -> torch.Tensor:
if density >= 1:
return tensor

if density <= epsilon or density>=(1-epsilon):
raise ValueError(f"Error: density +- epsilon must be in the range (0, 1). density + epsilon = {density+epsilon}, density - epsilon = {density-epsilon}")
if density <= epsilon or density >= (1 - epsilon):
raise ValueError(
f"Error: density +- epsilon must be in the range (0, 1). density + epsilon = {density+epsilon}, density - epsilon = {density-epsilon}"
)

if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
work_dtype = tensor.dtype
else:
work_dtype = torch.float32

if len(tensor.shape)<2:
if len(tensor.shape) < 2:
tensor = tensor.unsqueeze(0)

# Get Rank matrix for the delta values
tensor_abs = torch.abs(tensor)

sorted_indices = torch.argsort(tensor_abs, dim=1, descending=False)

ranking_tensor = torch.zeros_like(tensor_abs, dtype=work_dtype)
for i in range(tensor_abs.size(0)):
ranking_tensor[i][sorted_indices[i]] = torch.arange(1, tensor.size(1) + 1, dtype= work_dtype).to(tensor.device)

ranking_tensor[i][sorted_indices[i]] = torch.arange(
1, tensor.size(1) + 1, dtype=work_dtype
).to(tensor.device)

# Normalise rank matrix to the probability range to density +- epsilon
range_vals = ranking_tensor.max(dim = 1, keepdim=True).values - ranking_tensor.min(dim = 1, keepdim=True).values
norm_metrics = (ranking_tensor - ranking_tensor.min(dim = 1, keepdim=True).values)/(range_vals)
final_probabilities = (density-epsilon) + norm_metrics * (2*epsilon)

range_vals = (
ranking_tensor.max(dim=1, keepdim=True).values
- ranking_tensor.min(dim=1, keepdim=True).values
)
norm_metrics = (ranking_tensor - ranking_tensor.min(dim=1, keepdim=True).values) / (
range_vals
)
final_probabilities = (density - epsilon) + norm_metrics * (2 * epsilon)

mask = torch.bernoulli(final_probabilities).to(work_dtype)
res = tensor.to(work_dtype) * mask

Expand Down

0 comments on commit 0031a40

Please sign in to comment.