
Commit e25a175
fix the lammps model backward comp
ilyes319 committed Aug 27, 2024
1 parent 529b261 commit e25a175
Showing 3 changed files with 41 additions and 6 deletions.
mace/calculators/lammps_mace.py (2 additions, 0 deletions)
@@ -14,6 +14,8 @@ def __init__(self, model, **kwargs):
         self.register_buffer("atomic_numbers", model.atomic_numbers)
         self.register_buffer("r_max", model.r_max)
         self.register_buffer("num_interactions", model.num_interactions)
+        if not hasattr(model, "heads"):
+            model.heads = [None]
         self.register_buffer(
             "head",
             torch.tensor(
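The two added lines are a backward-compatibility shim: checkpoints saved before multihead support carry no `heads` attribute, so the wrapper installs a one-element `[None]` list before registering the `head` buffer, and later `heads.index(...)` lookups keep working. A minimal sketch of the pattern, using a hypothetical `OldModel` as a stand-in for a pre-multihead checkpoint (not a real MACE class):

```python
import torch


class OldModel(torch.nn.Module):
    """Hypothetical stand-in for a checkpoint saved before multihead support."""


model = OldModel()
if not hasattr(model, "heads"):
    model.heads = [None]  # same default the commit installs

# The wrapper can now always resolve a head index, old checkpoint or not.
head = torch.tensor(model.heads.index(None), dtype=torch.long)
print(head)  # tensor(0)
```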
mace/cli/create_lammps_model.py (35 additions, 3 deletions)
@@ -1,8 +1,6 @@
 import argparse
 
 import torch
-from e3nn.util import jit
 
 from mace.calculators import LAMMPS_MACE
 
-
@@ -24,12 +22,46 @@ def parse_args():
     return parser.parse_args()


+def select_head(model):
+    if hasattr(model, "heads"):
+        heads = model.heads
+    else:
+        heads = [None]
+
+    if len(heads) == 1:
+        print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
+        return heads[0]
+
+    print("Available heads in the model:")
+    for i, head in enumerate(heads):
+        print(f"{i + 1}: {head}")
+
+    # Ask the user to select a head
+    selected = input(
+        f"Select a head by number (default: {len(heads)}, press Enter to skip): "
+    )
+
+    if selected.isdigit() and 1 <= int(selected) <= len(heads):
+        return heads[int(selected) - 1]
+    elif selected == "":
+        print("No head selected. Proceeding without specifying a head.")
+        return None
+    else:
+        print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
+        return heads[-1]
+
+
 def main():
     args = parse_args()
     model_path = args.model_path  # takes model name as command-line input
-    head = args.head
     model = torch.load(model_path)
     model = model.double().to("cpu")
+
+    if args.head is None:
+        head = select_head(model)
+    else:
+        head = args.head
+
     lammps_model = (
         LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model)
     )
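`select_head` only prompts when there is a genuine choice: a single-head model returns immediately, and a legacy model with no `heads` attribute is treated as `[None]`. A small sketch of the two non-interactive paths (the `SimpleNamespace` stand-ins and the head name are illustrative, not MACE objects):

```python
from types import SimpleNamespace

# Exactly one head: returned without prompting the user.
single = SimpleNamespace(heads=["default"])
assert select_head(single) == "default"

# Legacy model with no `heads` attribute: treated as [None].
legacy = SimpleNamespace()
assert select_head(legacy) is None
```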
mace/modules/blocks.py (4 additions, 3 deletions)
@@ -81,8 +81,9 @@ def forward(
         self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
     ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
         x = self.non_linearity(self.linear_1(x))
-        if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None:
-            x = mask_head(x, heads, self.num_heads)
+        if hasattr(self, "num_heads"):
+            if self.num_heads > 1 and heads is not None:
+                x = mask_head(x, heads, self.num_heads)
         return self.linear_2(x)  # [n_nodes, len(heads)]


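Folding the `hasattr` test out of the compound condition is presumably the backward-compatibility fix here: `torch.jit.script` resolves `hasattr(self, "num_heads")` at compile time, and with the nested form the inner branch that reads `self.num_heads` is dropped entirely for old models that lack the attribute, instead of being compiled against a missing member. A minimal sketch of the pattern with a hypothetical readout module (not the real MACE block):

```python
import torch


class TinyReadout(torch.nn.Module):
    """Hypothetical readout mimicking an old checkpoint: no num_heads attribute."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "num_heads"):  # resolved to False at script time
            if self.num_heads > 1:  # dead branch, never compiled
                x = x * 0.0  # stand-in for mask_head
        return self.linear(x)


scripted = torch.jit.script(TinyReadout())  # compiles despite the missing attribute
print(scripted(torch.ones(1, 4)))
```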
@@ -620,7 +621,7 @@ def _setup(self) -> None:
         input_dim = self.edge_feats_irreps.num_irreps
         self.conv_tp_weights = nn.FullyConnectedNet(
             [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
-            torch.nn.functional.silu,
+            torch.nn.functional.silu,  # gate
         )
 
         # Linear
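For context, `nn.FullyConnectedNet` here is e3nn's plain MLP: it maps the scalar edge features to the tensor-product weights, with `silu` as the activation between layers (which the newly added `# gate` comment labels). An illustrative call with made-up layer sizes; the real dimensions come from `edge_feats_irreps`, `radial_MLP`, and `conv_tp.weight_numel`:

```python
import torch
from e3nn.nn import FullyConnectedNet

# [input_dim] + hidden layers + [output_dim], silu between layers.
radial = FullyConnectedNet([8, 64, 64, 16], torch.nn.functional.silu)
print(radial(torch.randn(5, 8)).shape)  # torch.Size([5, 16])
```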
