diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 8ad7e984..ed9f0e2c 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -8,12 +8,20 @@ @compile_mode("script") class LAMMPS_MACE(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, **kwargs): super().__init__() self.model = model self.register_buffer("atomic_numbers", model.atomic_numbers) self.register_buffer("r_max", model.r_max) self.register_buffer("num_interactions", model.num_interactions) + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[0])), + dtype=torch.long, + ), + ) + for param in self.model.parameters(): param.requires_grad = False @@ -27,6 +35,7 @@ def forward( compute_displacement = False if compute_virials: compute_displacement = True + data["head"] = self.head out = self.model( data, training=False, diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 4cae618f..858b708d 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -4,15 +4,34 @@ from e3nn.util import jit from mace.calculators import LAMMPS_MACE +import argparse -def main(): - assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} model_path" +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default="default", + ) + return parser.parse_args() + - model_path = sys.argv[1] # takes model name as command-line input +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input + head = args.head model = torch.load(model_path) model = model.double().to("cpu") - lammps_model = LAMMPS_MACE(model) + lammps_model = LAMMPS_MACE(model, head=head) lammps_model_compiled = jit.compile(lammps_model) lammps_model_compiled.save(model_path + "-lammps.pt")