From e25a1750cafd109139323e30b3e95125d8595caf Mon Sep 17 00:00:00 2001
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com>
Date: Tue, 27 Aug 2024 20:05:04 +0100
Subject: [PATCH] fix the lammps model backward comp

---
 mace/calculators/lammps_mace.py |  2 ++
 mace/cli/create_lammps_model.py | 38 ++++++++++++++++++++++++++++++---
 mace/modules/blocks.py          |  7 +++---
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py
index 182dc283..408dfaa8 100644
--- a/mace/calculators/lammps_mace.py
+++ b/mace/calculators/lammps_mace.py
@@ -14,6 +14,8 @@ def __init__(self, model, **kwargs):
         self.register_buffer("atomic_numbers", model.atomic_numbers)
         self.register_buffer("r_max", model.r_max)
         self.register_buffer("num_interactions", model.num_interactions)
+        if not hasattr(model, "heads"):
+            model.heads = [None]
         self.register_buffer(
             "head",
             torch.tensor(
diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py
index c023f640..2d12d245 100644
--- a/mace/cli/create_lammps_model.py
+++ b/mace/cli/create_lammps_model.py
@@ -1,8 +1,6 @@
 import argparse
-
 import torch
 from e3nn.util import jit
-
 from mace.calculators import LAMMPS_MACE
 
 
@@ -24,12 +22,46 @@ def parse_args():
     return parser.parse_args()
 
 
+def select_head(model):
+    if hasattr(model, "heads"):
+        heads = model.heads
+    else:
+        heads = [None]
+
+    if len(heads) == 1:
+        print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
+        return heads[0]
+
+    print("Available heads in the model:")
+    for i, head in enumerate(heads):
+        print(f"{i + 1}: {head}")
+
+    # Ask the user to select a head
+    selected = input(
+        f"Select a head by number (default: {len(heads)}, press Enter to skip): "
+    )
+
+    if selected.isdigit() and 1 <= int(selected) <= len(heads):
+        return heads[int(selected) - 1]
+    elif selected == "":
+        print("No head selected. Proceeding without specifying a head.")
+        return None
+    else:
+        print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
+        return heads[-1]
+
+
 def main():
     args = parse_args()
     model_path = args.model_path  # takes model name as command-line input
-    head = args.head
     model = torch.load(model_path)
     model = model.double().to("cpu")
+
+    if args.head is None:
+        head = select_head(model)
+    else:
+        head = args.head
+
     lammps_model = (
         LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model)
     )
diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py
index 48a9d22c..34539b0b 100644
--- a/mace/modules/blocks.py
+++ b/mace/modules/blocks.py
@@ -81,8 +81,9 @@ def forward(
         self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
     ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
         x = self.non_linearity(self.linear_1(x))
-        if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None:
-            x = mask_head(x, heads, self.num_heads)
+        if hasattr(self, "num_heads"):
+            if self.num_heads > 1 and heads is not None:
+                x = mask_head(x, heads, self.num_heads)
         return self.linear_2(x)  # [n_nodes, len(heads)]
 
 
@@ -620,7 +621,7 @@ def _setup(self) -> None:
         input_dim = self.edge_feats_irreps.num_irreps
         self.conv_tp_weights = nn.FullyConnectedNet(
             [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel],
-            torch.nn.functional.silu,
+            torch.nn.functional.silu,  # gate
         )
 
         # Linear