Skip to content

Commit

Permalink
add selection head script to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Nov 12, 2024
1 parent 1300ad1 commit 68149a7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
File renamed without changes.
33 changes: 33 additions & 0 deletions mace/cli/select_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from argparse import ArgumentParser

import torch

from mace.tools.scripts_utils import remove_pt_head


def main():
parser = ArgumentParser()
parser.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ console_scripts =
mace_run_train = mace.cli.run_train:main
mace_prepare_data = mace.cli.preprocess_data:main
mace_finetuning = mace.cli.fine_tuning_select:main
mace_convert_dev = mace.cli.convert_dev:main
mace_convert_device = mace.cli.convert_device:main
mace_select_head = mace.cli.select_head:main

[options.extras_require]
wandb = wandb
Expand Down

0 comments on commit 68149a7

Please sign in to comment.