Skip to content

Commit

Permalink
Code simplification
Browse files (browse the repository at this point in the history)
  • Loading branch information
metric-space committed Jul 30, 2024
1 parent aa64f01 commit c57dd5e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
13 changes: 5 additions & 8 deletions mergekit/scripts/ABM/activations_based_merge.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,8 +16,8 @@


@click.command("mergekit-activation-based-merge")
@click.argument("model_path", type=str)
@click.argument("secondary_model_path", type=str)
@click.argument("model_path", type=str, help="Path to the anchor model")
@click.argument("secondary_model_path", type=str, help="Path to the secondary model")
@click.argument("merge_unmerge_directory", type=str)
@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
@click.option(
Expand Down Expand Up @@ -121,8 +121,8 @@ def main(

if merge_matrix is not None:
if weight_info.is_embed:
w = (merge_matrix[0] @ w.T).T
w2 = (merge_matrix[1] @ w2.T).T
w = w @ merge_matrix[0].T
w2 = w2 @ merge_matrix[1].T
else:
w = merge_matrix[0] @ w
w2 = merge_matrix[1] @ w2
Expand Down Expand Up @@ -151,10 +151,7 @@ def main(
)

# average weights and save them
if merge_matrix:
w = w + w2
else:
w = (w + w2) / 2
w = (w + w2) / 2
writer.save_tensor(weight_info.name, w)
writer.finalize()

Expand Down
19 changes: 6 additions & 13 deletions mergekit/scripts/ABM/extract_permutation_matrices.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,14 +39,10 @@ def match_tensors_permute(
new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
mats.append(new_mat.T)

unmerge_mats = mats
unmerge = torch.cat(mats, dim=0)
merge = unmerge.clone().T

unmerge = torch.cat(unmerge_mats, dim=0)

merge = torch.cat(mats, dim=0)
merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)

return merge.T, unmerge
return merge, unmerge


def match_tensors_permute_MHA(
Expand Down Expand Up @@ -111,13 +107,10 @@ def match_tensors_permute_MHA(
]
mats.append(new_mat.T)

unmerge_mats = mats

unmerge = torch.cat(unmerge_mats, dim=0)
merge = torch.cat(mats, dim=0)
merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
unmerge = torch.cat(mats, dim=0)
merge = unmerge.clone().T

return merge.T, unmerge
return merge, unmerge


@click.command("mergekit-abm-extract-permutations")
Expand Down

0 comments on commit c57dd5e

Please sign in to comment.