Skip to content

Commit

Permalink
Merge pull request ACEsuit#642 from beckobert/eval_configs_multihead
Browse files Browse the repository at this point in the history
Fix multihead prediction for eval_configs.py
  • Loading branch information
ilyes319 authored Oct 16, 2024
2 parents 1cddd99 + 1b7e369 commit 3867c26
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def parse_args() -> argparse.Namespace:
type=str,
default="MACE_",
)
parser.add_argument(
"--head",
help="Model head used for evaluation",
type=str,
required=False,
default=None
)
return parser.parse_args()


Expand All @@ -76,14 +83,22 @@ def run(args: argparse.Namespace) -> None:

# Load data and prepare input
atoms_list = ase.io.read(args.configs, index=":")
if args.head is not None:
for atoms in atoms_list:
atoms.info["head"] = args.head
configs = [data.config_from_atoms(atoms) for atoms in atoms_list]

z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

try:
heads = model.heads
except AttributeError:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=z_table, cutoff=float(model.r_max)
config, z_table=z_table, cutoff=float(model.r_max), heads=heads
)
for config in configs
],
Expand Down

0 comments on commit 3867c26

Please sign in to comment.