Skip to content

Commit

Permalink
fix avg_neighbors
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Aug 23, 2024
1 parent a159369 commit f09a08f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def run(args: argparse.Namespace) -> None:
head_name="pt_head",
E0s="foundation",
statistics_file=args.statistics_file,
compute_avg_num_neighbors=False,
avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
)
collections = assemble_mp_data(args, tag, head_configs)
head_config_pt.collections = collections
Expand Down Expand Up @@ -264,6 +266,8 @@ def run(args: argparse.Namespace) -> None:
charges_key=args.charges_key,
keep_isolated_atoms=args.keep_isolated_atoms,
collections=collections,
avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
compute_avg_num_neighbors=False,
)
head_config_pt.collections = collections
logging.info(
Expand Down Expand Up @@ -473,7 +477,7 @@ def run(args: argparse.Namespace) -> None:
else:
args.avg_num_neighbors = avg_num_neighbors
else:
assert not any(head_config.avg_num_neighbors is None for head_config in head_configs), "Average number of neighbors must be provided in the configuration"
assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration"
args.avg_num_neighbors = max([head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None])
logging.info(f"Average number of neighbors: {args.avg_num_neighbors}")

Expand Down
1 change: 0 additions & 1 deletion mace/data/hdf5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __getitem__(self, index):
dipole=unpack_value(subgrp["dipole"][()]),
charges=unpack_value(subgrp["charges"][()]),
weight=unpack_value(subgrp["weight"][()]),
head=unpack_value(subgrp["head"][()]),
energy_weight=unpack_value(subgrp["energy_weight"][()]),
forces_weight=unpack_value(subgrp["forces_weight"][()]),
stress_weight=unpack_value(subgrp["stress_weight"][()]),
Expand Down

0 comments on commit f09a08f

Please sign in to comment.