diff --git a/mace/cli/convert_dev.py b/mace/cli/convert_device.py similarity index 100% rename from mace/cli/convert_dev.py rename to mace/cli/convert_device.py diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py new file mode 100644 index 00000000..a1e27229 --- /dev/null +++ b/mace/cli/select_head.py @@ -0,0 +1,33 @@ +from argparse import ArgumentParser + +import torch + +from mace.tools.scripts_utils import remove_pt_head + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--head_name", + "-n", + help="name of the head to extract", + default=None, + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file) + model_single = remove_pt_head(model, args.head_name) + torch.save(model_single, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 139f914e..76467fda 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,8 @@ console_scripts = mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main mace_finetuning = mace.cli.fine_tuning_select:main - mace_convert_dev = mace.cli.convert_dev:main + mace_convert_device = mace.cli.convert_device:main + mace_select_head = mace.cli.select_head:main [options.extras_require] wandb = wandb