Skip to content

Commit

Permalink
add head selection in lammps
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Aug 27, 2024
1 parent 5e70bfe commit 5017b84
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
11 changes: 10 additions & 1 deletion mace/calculators/lammps_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@

@compile_mode("script")
class LAMMPS_MACE(torch.nn.Module):
def __init__(self, model):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
self.register_buffer("atomic_numbers", model.atomic_numbers)
self.register_buffer("r_max", model.r_max)
self.register_buffer("num_interactions", model.num_interactions)
self.register_buffer(
"head",
torch.tensor(
self.model.heads.index(kwargs.get("head", self.model.heads[0])),
dtype=torch.long,
),
)

for param in self.model.parameters():
param.requires_grad = False

Expand All @@ -27,6 +35,7 @@ def forward(
compute_displacement = False
if compute_virials:
compute_displacement = True
data["head"] = self.head
out = self.model(
data,
training=False,
Expand Down
27 changes: 23 additions & 4 deletions mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,34 @@
from e3nn.util import jit

from mace.calculators import LAMMPS_MACE
import argparse


def main():
assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} model_path"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the model to be converted to LAMMPS",
)
parser.add_argument(
"--head",
type=str,
nargs="?",
help="Head of the model to be converted to LAMMPS",
default="default",
)
return parser.parse_args()


model_path = sys.argv[1] # takes model name as command-line input
def main():
args = parse_args()
model_path = args.model_path # takes model name as command-line input
head = args.head
model = torch.load(model_path)
model = model.double().to("cpu")
lammps_model = LAMMPS_MACE(model)
lammps_model = LAMMPS_MACE(model, head=head)
lammps_model_compiled = jit.compile(lammps_model)
lammps_model_compiled.save(model_path + "-lammps.pt")

Expand Down

0 comments on commit 5017b84

Please sign in to comment.