Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into calc-from-module
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Sep 25, 2024
2 parents e06e94e + 22a2e3e commit 10302b1
Show file tree
Hide file tree
Showing 13 changed files with 487 additions and 173 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include mace/py.typed
3 changes: 2 additions & 1 deletion mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def run(args: argparse.Namespace):

# Data preparation
collections, atomic_energies_dict = get_dataset_from_xyz(
work_dir=args.work_dir,
train_path=args.train_file,
valid_path=args.valid_file,
valid_fraction=args.valid_fraction,
Expand Down Expand Up @@ -211,7 +212,7 @@ def run(args: argparse.Namespace):
atomic_energies: np.ndarray = np.array(
[atomic_energies_dict[z] for z in z_table.zs]
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")
logging.info(f"Atomic Energies: {atomic_energies.tolist()}")
_inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
logging.info(f"Average number of neighbors: {avg_num_neighbors}")
Expand Down
172 changes: 120 additions & 52 deletions mace/cli/run_train.py

Large diffs are not rendered by default.

69 changes: 41 additions & 28 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Configuration:


def random_train_valid_split(
items: Sequence, valid_fraction: float, seed: int
items: Sequence, valid_fraction: float, seed: int, work_dir: str
) -> Tuple[List, List]:
assert 0.0 < valid_fraction < 1.0

Expand All @@ -63,6 +63,19 @@ def random_train_valid_split(
indices = list(range(size))
rng = np.random.default_rng(seed)
rng.shuffle(indices)
if len(indices[train_size:]) < 10:
logging.info(
f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}"
)
else:
# Save indices to file
with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f:
for index in indices[train_size:]:
f.write(f"{index}\n")

logging.info(
f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt"
)

return (
[items[i] for i in indices[:train_size]],
Expand All @@ -72,12 +85,12 @@ def random_train_valid_split(

def config_from_atoms_list(
atoms_list: List[ase.Atoms],
energy_key="energy",
forces_key="forces",
stress_key="stress",
virials_key="virials",
dipole_key="dipole",
charges_key="charges",
energy_key="REF_energy",
forces_key="REF_forces",
stress_key="REF_stress",
virials_key="REF_virials",
dipole_key="REF_dipole",
charges_key="REF_charges",
config_type_weights: Dict[str, float] = None,
) -> Configurations:
"""Convert list of ase.Atoms into Configurations"""
Expand All @@ -103,12 +116,12 @@ def config_from_atoms_list(

def config_from_atoms(
atoms: ase.Atoms,
energy_key="energy",
forces_key="forces",
stress_key="stress",
virials_key="virials",
dipole_key="dipole",
charges_key="charges",
energy_key="REF_energy",
forces_key="REF_forces",
stress_key="REF_stress",
virials_key="REF_virials",
dipole_key="REF_dipole",
charges_key="REF_charges",
config_type_weights: Dict[str, float] = None,
) -> Configuration:
"""Convert ase.Atoms to Configuration"""
Expand Down Expand Up @@ -192,41 +205,41 @@ def test_config_types(
def load_from_xyz(
file_path: str,
config_type_weights: Dict,
energy_key: str = "energy",
forces_key: str = "forces",
stress_key: str = "stress",
virials_key: str = "virials",
dipole_key: str = "dipole",
charges_key: str = "charges",
energy_key: str = "REF_energy",
forces_key: str = "REF_forces",
stress_key: str = "REF_stress",
virials_key: str = "REF_virials",
dipole_key: str = "REF_dipole",
charges_key: str = "REF_charges",
extract_atomic_energies: bool = False,
keep_isolated_atoms: bool = False,
) -> Tuple[Dict[int, float], Configurations]:
atoms_list = ase.io.read(file_path, index=":")
if energy_key == "energy":
logging.info(
"Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_energy'. You need to use --energy_key='REF_energy', to tell the key name chosen."
logging.warning(
"Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name."
)
energy_key = "REF_energy"
for atoms in atoms_list:
try:
atoms.info["REF_energy"] = atoms.get_potential_energy()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract energy: {e}")
logging.error(f"Failed to extract energy: {e}")
atoms.info["REF_energy"] = None
if forces_key == "forces":
logging.info(
"Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_forces'. You need to use --forces_key='REF_forces', to tell the key name chosen."
logging.warning(
"Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name."
)
forces_key = "REF_forces"
for atoms in atoms_list:
try:
atoms.arrays["REF_forces"] = atoms.get_forces()
except Exception as e: # pylint: disable=W0703
logging.warning(f"Failed to extract forces: {e}")
logging.error(f"Failed to extract forces: {e}")
atoms.arrays["REF_forces"] = None
if stress_key == "stress":
logging.info(
"Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_stress'. You need to use --stress_key='REF_stress', to tell the key name chosen."
logging.warning(
"Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name."
)
stress_key = "REF_stress"
for atoms in atoms_list:
Expand Down Expand Up @@ -298,7 +311,7 @@ def compute_average_E0s(
for i, z in enumerate(z_table.zs):
atomic_energies_dict[z] = E0s[i]
except np.linalg.LinAlgError:
logging.warning(
logging.error(
"Failed to compute E0s using least squares regression, using the same for all atoms"
)
atomic_energies_dict = {}
Expand Down
1 change: 1 addition & 0 deletions mace/py.typed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 2 additions & 0 deletions mace/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser
from .arg_parser_tools import check_args
from .cg import U_matrix_real
from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState
from .finetuning_utils import load_foundations
Expand Down Expand Up @@ -39,6 +40,7 @@
"to_numpy",
"to_one_hot",
"build_default_arg_parser",
"check_args",
"set_seeds",
"init_device",
"setup_logger",
Expand Down
Loading

0 comments on commit 10302b1

Please sign in to comment.