diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index dcd2b8e5..c15014ad 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -182,10 +182,18 @@ def __init__( [int(z) for z in self.models[0].atomic_numbers] ) self.charges_key = charges_key + try: - self.heads = self.models[0].heads + self.available_heads = self.models[0].heads except AttributeError: - self.heads = ["Default"] + self.available_heads = ["Default"] + self.head = kwargs.get("head", "Default") + assert ( + self.head in self.available_heads + ), f"specified head {self.head}, but model available model heads are {self.available_heads}" + + print("using head", self.head, "out of", self.available_heads) + model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -235,11 +243,19 @@ def _create_result_tensors( return dict_of_tensors def _atoms_to_batch(self, atoms): - config = data.config_from_atoms(atoms, charges_key=self.charges_key) + keyspec = data.KeySpecification( + info_keys={}, arrays_keys={"charges": self.charges_key} + ) + config = data.config_from_atoms( + atoms, key_specification=keyspec, head_name=self.head + ) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.available_heads, ) ], batch_size=1, diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index ef9f1343..ea21a0d7 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -17,6 +17,7 @@ import tqdm from mace import data, tools +from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.data.utils import save_configurations_as_HDF5 from mace.modules import compute_statistics from mace.tools import torch_geometric @@ -144,6 +145,10 @@ def run(args: argparse.Namespace): new hdf5 file that is ready for training with on-the-fly dataloading """ + # currently support only command line property_key syntax + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + # Setup tools.set_seeds(args.seed) random.seed(args.seed) @@ -177,12 +182,8 @@ def run(args: argparse.Namespace): config_type_weights=config_type_weights, test_path=args.test_file, seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, + head_name=None, ) # Atomic number table diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7f1a5e74..9a754ffc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -24,6 +24,7 @@ import mace from mace import data, tools from mace.calculators.foundations_models import mace_mp, mace_off +from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools import torch_geometric from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( @@ -70,6 +71,10 @@ def run(args: argparse.Namespace) -> None: tag = tools.get_tag(name=args.name, seed=args.seed) args, input_log_messages = tools.check_args(args) + # default keyspec to update using heads dictionary + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + if args.device == "xpu": try: import intel_extension_for_pytorch as ipex @@ -153,12 +158,26 @@ def run(args: argparse.Namespace) -> None: if args.heads is not None: args.heads = ast.literal_eval(args.heads) + for _, head_dict in args.heads.items(): + # priority is global args < head property_key values < head info_keys+arrays_keys + head_keyspec = deepcopy(args.key_specification) + update_keyspec_from_kwargs(head_keyspec, head_dict) + head_keyspec.update( + info_keys=head_dict.get("info_keys", {}), + arrays_keys=head_dict.get("arrays_keys", {}), + ) + head_dict["key_specification"] = head_keyspec else: args.heads = prepare_default_head(args) logging.info("===========LOADING INPUT DATA===========") heads = list(args.heads.keys()) logging.info(f"Using heads: {heads}") + logging.info("Using the key specifications to parse data:") + for name, head_dict in args.heads.items(): + head_keyspec = head_dict["key_specification"] + logging.info(f"{name}: {head_keyspec}") + head_configs: List[HeadConfig] = [] for head, head_args in args.heads.items(): logging.info(f"============= Processing head {head} ===========") @@ -187,7 +206,6 @@ def run(args: argparse.Namespace) -> None: head_config.atomic_energies_dict = ast.literal_eval( statistics["atomic_energies"] ) - # Data preparation if check_path_ase_read(head_config.train_file): if head_config.valid_file is not None: @@ -205,12 +223,7 @@ def run(args: argparse.Namespace) -> None: config_type_weights=config_type_weights, test_path=head_config.test_file, seed=args.seed, - energy_key=head_config.energy_key, - forces_key=head_config.forces_key, - stress_key=head_config.stress_key, - virials_key=head_config.virials_key, - dipole_key=head_config.dipole_key, - charges_key=head_config.charges_key, + key_specification=head_config.key_specification, head_name=head_config.head_name, keep_isolated_atoms=head_config.keep_isolated_atoms, ) @@ -251,14 +264,21 @@ def run(args: argparse.Namespace) -> None: "Using foundation model for multiheads finetuning with Materials Project data" ) heads = list(dict.fromkeys(["pt_head"] + heads)) + mp_keyspec = KeySpecification() + update_keyspec_from_kwargs(mp_keyspec, vars(args)) + mp_keyspec.update( + info_keys={"energy": "energy", "stress": "stress"}, + arrays_keys={"forces": "forces"}, + ) head_config_pt = HeadConfig( head_name="pt_head", E0s="foundation", statistics_file=args.statistics_file, + key_specification=mp_keyspec, compute_avg_num_neighbors=False, avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, ) - collections = assemble_mp_data(args, tag, head_configs) + collections = assemble_mp_data(args, tag, head_configs, head_config_pt) head_config_pt.collections = collections head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" head_configs.append(head_config_pt) @@ -275,12 +295,7 @@ def run(args: argparse.Namespace) -> None: config_type_weights=None, test_path=None, seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, head_name="pt_head", keep_isolated_atoms=args.keep_isolated_atoms, ) @@ -292,12 +307,7 @@ def run(args: argparse.Namespace) -> None: statistics_file=args.statistics_file, valid_fraction=args.valid_fraction, config_type_weights=None, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, keep_isolated_atoms=args.keep_isolated_atoms, collections=collections, avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, diff --git a/mace/data/__init__.py b/mace/data/__init__.py index c10a3698..07999d39 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -4,6 +4,7 @@ from .utils import ( Configuration, Configurations, + KeySpecification, compute_average_E0s, config_from_atoms, config_from_atoms_list, @@ -13,6 +14,7 @@ save_configurations_as_HDF5, save_dataset_as_HDF5, test_config_types, + update_keyspec_from_kwargs, ) __all__ = [ @@ -31,4 +33,6 @@ "dataset_from_sharded_hdf5", "save_AtomicData_to_HDF5", "save_configurations_as_HDF5", + "KeySpecification", + "update_keyspec_from_kwargs", ] diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index cb4edd94..14dd39df 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -4,6 +4,7 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +from copy import deepcopy from typing import Optional, Sequence import torch.utils.data @@ -42,6 +43,8 @@ class AtomicData(torch_geometric.data.Data): forces_weight: torch.Tensor stress_weight: torch.Tensor virials_weight: torch.Tensor + dipole_weight: torch.Tensor + charges_weight: torch.Tensor def __init__( self, @@ -57,6 +60,8 @@ def __init__( forces_weight: Optional[torch.Tensor], # [,] stress_weight: Optional[torch.Tensor], # [,] virials_weight: Optional[torch.Tensor], # [,] + dipole_weight: Optional[torch.Tensor], # [,] + charges_weight: Optional[torch.Tensor], # [,] forces: Optional[torch.Tensor], # [n_nodes, 3] energy: Optional[torch.Tensor], # [, ] stress: Optional[torch.Tensor], # [1,3,3] @@ -78,6 +83,8 @@ def __init__( assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 assert virials_weight is None or len(virials_weight.shape) == 0 + assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight + assert charges_weight is None or len(charges_weight.shape) == 0 assert cell is None or cell.shape == (3, 3) assert forces is None or forces.shape == (num_nodes, 3) assert energy is None or len(energy.shape) == 0 @@ -100,6 +107,8 @@ def __init__( "forces_weight": forces_weight, "stress_weight": stress_weight, "virials_weight": virials_weight, + "dipole_weight": dipole_weight, + "charges_weight": charges_weight, "forces": forces, "energy": energy, "stress": stress, @@ -116,11 +125,15 @@ def from_config( z_table: AtomicNumberTable, cutoff: float, heads: Optional[list] = None, + **kwargs, # pylint: disable=unused-argument ) -> "AtomicData": if heads is None: - heads = ["default"] + heads = ["Default"] edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell + positions=config.positions, + cutoff=cutoff, + pbc=deepcopy(config.pbc), + cell=deepcopy(config.cell), ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) one_hot = to_one_hot( @@ -140,69 +153,113 @@ def from_config( ).view(3, 3) ) + num_atoms = len(config.atomic_numbers) + weight = ( torch.tensor(config.weight, dtype=torch.get_default_dtype()) if config.weight is not None - else 1 + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) energy_weight = ( - torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) - if config.energy_weight is not None - else 1 + torch.tensor( + config.property_weights.get("energy"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("energy") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces_weight = ( - torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) - if config.forces_weight is not None - else 1 + torch.tensor( + config.property_weights.get("forces"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("forces") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) stress_weight = ( - torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) - if config.stress_weight is not None - else 1 + torch.tensor( + config.property_weights.get("stress"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("stress") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) virials_weight = ( - torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) - if config.virials_weight is not None - else 1 + torch.tensor( + config.property_weights.get("virials"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("virials") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + dipole_weight = ( + torch.tensor( + config.property_weights.get("dipole"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("dipole") is not None + else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + ) + if len(dipole_weight.shape) == 0: + dipole_weight = dipole_weight * torch.tensor( + [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() + ) + elif len(dipole_weight.shape) == 1: + dipole_weight = dipole_weight.unsqueeze(0) + + charges_weight = ( + torch.tensor( + config.property_weights.get("charges"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("charges") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces = ( - torch.tensor(config.forces, dtype=torch.get_default_dtype()) - if config.forces is not None - else None + torch.tensor( + config.properties.get("forces"), dtype=torch.get_default_dtype() + ) + if config.properties.get("forces") is not None + else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) ) energy = ( - torch.tensor(config.energy, dtype=torch.get_default_dtype()) - if config.energy is not None - else None + torch.tensor( + config.properties.get("energy"), dtype=torch.get_default_dtype() + ) + if config.properties.get("energy") is not None + else torch.tensor(0.0, dtype=torch.get_default_dtype()) ) stress = ( voigt_to_matrix( - torch.tensor(config.stress, dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("stress"), dtype=torch.get_default_dtype() + ) ).unsqueeze(0) - if config.stress is not None - else None + if config.properties.get("stress") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) virials = ( voigt_to_matrix( - torch.tensor(config.virials, dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("virials"), dtype=torch.get_default_dtype() + ) ).unsqueeze(0) - if config.virials is not None - else None + if config.properties.get("virials") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) dipole = ( - torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) - if config.dipole is not None - else None + torch.tensor( + config.properties.get("dipole"), dtype=torch.get_default_dtype() + ).unsqueeze(0) + if config.properties.get("dipole") is not None + else torch.zeros(1, 3, dtype=torch.get_default_dtype()) ) charges = ( - torch.tensor(config.charges, dtype=torch.get_default_dtype()) - if config.charges is not None - else None + torch.tensor( + config.properties.get("charges"), dtype=torch.get_default_dtype() + ) + if config.properties.get("charges") is not None + else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) ) return cls( @@ -218,6 +275,8 @@ def from_config( forces_weight=forces_weight, stress_weight=stress_weight, virials_weight=virials_weight, + dipole_weight=dipole_weight, + charges_weight=charges_weight, forces=forces, energy=energy, stress=stress, diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 477ccd3f..f7755175 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -10,7 +10,9 @@ class HDF5Dataset(Dataset): - def __init__(self, file_path, r_max, z_table, **kwargs): + def __init__( + self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs + ): super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments self.file_path = file_path self._file = None @@ -19,6 +21,7 @@ def __init__(self, file_path, r_max, z_table, **kwargs): self.length = len(self.file.keys()) * self.batch_size self.r_max = r_max self.z_table = z_table + self.atomic_dataclass = atomic_dataclass try: self.drop_last = bool(self.file.attrs["drop_last"]) except KeyError: @@ -48,31 +51,32 @@ def __getitem__(self, index): config_index = index % self.batch_size grp = self.file["config_batch_" + str(batch_index)] subgrp = grp["config_" + str(config_index)] + + properties = {} + property_weights = {} + for key in subgrp["properties"]: + properties[key] = unpack_value(subgrp["properties"][key][()]) + for key in subgrp["property_weights"]: + property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) + config = Configuration( atomic_numbers=subgrp["atomic_numbers"][()], positions=subgrp["positions"][()], - energy=unpack_value(subgrp["energy"][()]), - forces=unpack_value(subgrp["forces"][()]), - stress=unpack_value(subgrp["stress"][()]), - virials=unpack_value(subgrp["virials"][()]), - dipole=unpack_value(subgrp["dipole"][()]), - charges=unpack_value(subgrp["charges"][()]), + properties=properties, weight=unpack_value(subgrp["weight"][()]), - energy_weight=unpack_value(subgrp["energy_weight"][()]), - forces_weight=unpack_value(subgrp["forces_weight"][()]), - stress_weight=unpack_value(subgrp["stress_weight"][()]), - virials_weight=unpack_value(subgrp["virials_weight"][()]), + property_weights=property_weights, config_type=unpack_value(subgrp["config_type"][()]), pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), ) if config.head is None: config.head = self.kwargs.get("head") - atomic_data = AtomicData.from_config( + atomic_data = self.atomic_dataclass.from_config( config, z_table=self.z_table, cutoff=self.r_max, heads=self.kwargs.get("heads", ["Default"]), + **self.kwargs, ) return atomic_data diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 21296fa6..03728969 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -10,7 +10,7 @@ def get_neighborhood( pbc: Optional[Tuple[bool, bool, bool]] = None, cell: Optional[np.ndarray] = None, # [3, 3] true_self_interaction=False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: if pbc is None: pbc = (False, False, False) diff --git a/mace/data/utils.py b/mace/data/utils.py index bb8e5448..a552f1a8 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -5,8 +5,8 @@ ########################################################################################### import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple import ase.data import ase.io @@ -15,12 +15,7 @@ from mace.tools import AtomicNumberTable -Vector = np.ndarray # [3,] Positions = np.ndarray # [..., 3] -Forces = np.ndarray # [..., 3] -Stress = np.ndarray # [6, ], [3,3], [9, ] -Virials = np.ndarray # [6, ], [3,3], [9, ] -Charges = np.ndarray # [..., 1] Cell = np.ndarray # [3,3] Pbc = tuple # (3,) @@ -28,24 +23,51 @@ DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} +@dataclass +class KeySpecification: + info_keys: Dict[str, str] = field(default_factory=dict) + arrays_keys: Dict[str, str] = field(default_factory=dict) + + def update( + self, + info_keys: Optional[Dict[str, str]] = None, + arrays_keys: Optional[Dict[str, str]] = None, + ): + if info_keys is not None: + self.info_keys.update(info_keys) + if arrays_keys is not None: + self.arrays_keys.update(arrays_keys) + return self + + +def update_keyspec_from_kwargs( + keyspec: KeySpecification, keydict: Dict[str, str] +) -> KeySpecification: + # convert command line style property_key arguments into a keyspec + infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] + arrays = ["forces_key", "charges_key"] + info_keys = {} + arrays_keys = {} + for key in infos: + if key in keydict: + info_keys[key[:-4]] = keydict[key] + for key in arrays: + if key in keydict: + arrays_keys[key[:-4]] = keydict[key] + keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) + return keyspec + + @dataclass class Configuration: atomic_numbers: np.ndarray positions: Positions # Angstrom - energy: Optional[float] = None # eV - forces: Optional[Forces] = None # eV/Angstrom - stress: Optional[Stress] = None # eV/Angstrom^3 - virials: Optional[Virials] = None # eV - dipole: Optional[Vector] = None # Debye - charges: Optional[Charges] = None # atomic unit + properties: Dict[str, Any] + property_weights: Dict[str, float] cell: Optional[Cell] = None pbc: Optional[Pbc] = None weight: float = 1.0 # weight of config in loss - energy_weight: float = 1.0 # weight of config energy in loss - forces_weight: float = 1.0 # weight of config forces in loss - stress_weight: float = 1.0 # weight of config stress in loss - virials_weight: float = 1.0 # weight of config virial in loss config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config head: Optional[str] = "Default" # head used to compute the config @@ -86,14 +108,9 @@ def random_train_valid_split( def config_from_atoms_list( atoms_list: List[ase.Atoms], - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - head_key="head", + key_specification: KeySpecification, config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", ) -> Configurations: """Convert list of ase.Atoms into Configurations""" if config_type_weights is None: @@ -104,14 +121,9 @@ def config_from_atoms_list( all_configs.append( config_from_atoms( atoms, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, config_type_weights=config_type_weights, + head_name=head_name, ) ) return all_configs @@ -119,26 +131,14 @@ def config_from_atoms_list( def config_from_atoms( atoms: ase.Atoms, - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - head_key="head", + key_specification: KeySpecification = KeySpecification(), config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", ) -> Configuration: """Convert ase.Atoms to Configuration""" if config_type_weights is None: config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - energy = atoms.info.get(energy_key, None) # eV - forces = atoms.arrays.get(forces_key, None) # eV / Ang - stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 - virials = atoms.info.get(virials_key, None) - dipole = atoms.info.get(dipole_key, None) # Debye - # Charges default to 0 instead of None if not found - charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit atomic_numbers = np.array( [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] ) @@ -148,45 +148,29 @@ def config_from_atoms( weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( config_type, 1.0 ) - energy_weight = atoms.info.get("config_energy_weight", 1.0) - forces_weight = atoms.info.get("config_forces_weight", 1.0) - stress_weight = atoms.info.get("config_stress_weight", 1.0) - virials_weight = atoms.info.get("config_virials_weight", 1.0) - - head = atoms.info.get(head_key, "Default") - - # fill in missing quantities but set their weight to 0.0 - if energy is None: - energy = 0.0 - energy_weight = 0.0 - if forces is None: - forces = np.zeros(np.shape(atoms.positions)) - forces_weight = 0.0 - if stress is None: - stress = np.zeros(6) - stress_weight = 0.0 - if virials is None: - virials = np.zeros((3, 3)) - virials_weight = 0.0 - if dipole is None: - dipole = np.zeros(3) - # dipoles_weight = 0.0 + + properties = {} + property_weights = {} + for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): + property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) + + for name, atoms_key in key_specification.info_keys.items(): + properties[name] = atoms.info.get(atoms_key, None) + if not atoms_key in atoms.info: + property_weights[name] = 0.0 + + for name, atoms_key in key_specification.arrays_keys.items(): + properties[name] = atoms.arrays.get(atoms_key, None) + if not atoms_key in atoms.arrays: + property_weights[name] = 0.0 return Configuration( atomic_numbers=atomic_numbers, positions=atoms.get_positions(), - energy=energy, - forces=forces, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, + properties=properties, weight=weight, - head=head, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, + property_weights=property_weights, + head=head_name, config_type=config_type, pbc=pbc, cell=cell, @@ -213,23 +197,20 @@ def test_config_types( def load_from_xyz( file_path: str, config_type_weights: Dict, - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "REF_virials", - dipole_key: str = "REF_dipole", - charges_key: str = "REF_charges", - head_key: str = "head", + key_specification: KeySpecification, head_name: str = "Default", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") + energy_key = key_specification.info_keys["energy"] + forces_key = key_specification.arrays_keys["forces"] + stress_key = key_specification.info_keys["stress"] if energy_key == "energy": logging.warning( "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." ) - energy_key = "REF_energy" + key_specification.info_keys["energy"] = "REF_energy" for atoms in atoms_list: try: atoms.info["REF_energy"] = atoms.get_potential_energy() @@ -240,7 +221,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." ) - forces_key = "REF_forces" + key_specification.arrays_keys["forces"] = "REF_forces" for atoms in atoms_list: try: atoms.arrays["REF_forces"] = atoms.get_forces() @@ -251,7 +232,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." ) - stress_key = "REF_stress" + key_specification.info_keys["stress"] = "REF_stress" for atoms in atoms_list: try: atoms.info["REF_stress"] = atoms.get_stress() @@ -265,7 +246,7 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name + atoms.info[key_specification.info_keys["head"]] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) @@ -291,13 +272,8 @@ def load_from_xyz( configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, + head_name=head_name, ) return atomic_energies_dict, configs @@ -314,7 +290,7 @@ def compute_average_E0s( A = np.zeros((len_train, len_zs)) B = np.zeros(len_train) for i in range(len_train): - B[i] = collections_train[i].energy + B[i] = collections_train[i].properties["energy"] for j, z in enumerate(z_table.zs): A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) try: @@ -335,26 +311,7 @@ def compute_average_E0s( def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: with h5py.File(out_name, "w") as f: for i, data in enumerate(dataset): - grp = f.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - grp["head"] = data.head + save_AtomicData_to_HDF5(data, i, f) def save_AtomicData_to_HDF5(data, i, h5_file) -> None: @@ -387,20 +344,15 @@ def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> N subgroup = grp.create_group(subgroup_name) subgroup["atomic_numbers"] = write_value(config.atomic_numbers) subgroup["positions"] = write_value(config.positions) - subgroup["energy"] = write_value(config.energy) - subgroup["forces"] = write_value(config.forces) - subgroup["stress"] = write_value(config.stress) - subgroup["virials"] = write_value(config.virials) - subgroup["head"] = write_value(config.head) - subgroup["dipole"] = write_value(config.dipole) - subgroup["charges"] = write_value(config.charges) + properties_subgrp = subgroup.create_group("properties") + for key, value in config.properties.items(): + properties_subgrp[key] = write_value(value) subgroup["cell"] = write_value(config.cell) subgroup["pbc"] = write_value(config.pbc) subgroup["weight"] = write_value(config.weight) - subgroup["energy_weight"] = write_value(config.energy_weight) - subgroup["forces_weight"] = write_value(config.forces_weight) - subgroup["stress_weight"] = write_value(config.stress_weight) - subgroup["virials_weight"] = write_value(config.virials_weight) + weights_subgrp = subgroup.create_group("property_weights") + for key, value in config.property_weights.items(): + weights_subgrp[key] = write_value(value) subgroup["config_type"] = write_value(config.config_type) diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab43..05578ac8 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -74,7 +74,7 @@ def __init__( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) if heads is None: - heads = ["default"] + heads = ["Default"] self.heads = heads if isinstance(correlation, int): correlation = [correlation] * num_interactions diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index e492c827..9db0ccd5 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -428,6 +428,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str, default="REF_dipole", ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default="head", + ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", @@ -847,6 +853,12 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: type=int, default=123, ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default="head", + ) return parser diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index ffde107f..16f6e491 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -8,6 +8,7 @@ import torch from mace.cli.fine_tuning_select import select_samples +from mace.data import KeySpecification from mace.tools.scripts_utils import ( SubsetCollection, dict_to_namespace, @@ -26,12 +27,7 @@ class HeadConfig: statistics_file: Optional[str] = None valid_fraction: Optional[float] = None config_type_weights: Optional[Dict[str, float]] = None - energy_key: Optional[str] = None - forces_key: Optional[str] = None - stress_key: Optional[str] = None - virials_key: Optional[str] = None - dipole_key: Optional[str] = None - charges_key: Optional[str] = None + key_specification: Optional[KeySpecification] = None keep_isolated_atoms: Optional[bool] = None atomic_numbers: Optional[Union[List[int], List[str]]] = None mean: Optional[float] = None @@ -47,7 +43,11 @@ class HeadConfig: def dict_head_to_dataclass( head: Dict[str, Any], head_name: str, args: argparse.Namespace ) -> HeadConfig: - + # parser+head args that have no defaults but are required + if (args.train_file is None) and (head.get("train_file", None) is None): + raise ValueError( + "train file is not set in the head config yaml or via command line args" + ) return HeadConfig( head_name=head_name, train_file=head.get("train_file", args.train_file), @@ -65,40 +65,33 @@ def dict_head_to_dataclass( mean=head.get("mean", args.mean), std=head.get("std", args.std), avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - energy_key=head.get("energy_key", args.energy_key), - forces_key=head.get("forces_key", args.forces_key), - stress_key=head.get("stress_key", args.stress_key), - virials_key=head.get("virials_key", args.virials_key), - dipole_key=head.get("dipole_key", args.dipole_key), - charges_key=head.get("charges_key", args.charges_key), + key_specification=head["key_specification"], keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: return { - "default": { + "Default": { "train_file": args.train_file, "valid_file": args.valid_file, "test_file": args.test_file, "test_dir": args.test_dir, "E0s": args.E0s, "statistics_file": args.statistics_file, + "key_specification": args.key_specification, "valid_fraction": args.valid_fraction, "config_type_weights": args.config_type_weights, - "energy_key": args.energy_key, - "forces_key": args.forces_key, - "stress_key": args.stress_key, - "virials_key": args.virials_key, - "dipole_key": args.dipole_key, - "charges_key": args.charges_key, "keep_isolated_atoms": args.keep_isolated_atoms, } } def assemble_mp_data( - args: argparse.Namespace, tag: str, head_configs: List[HeadConfig] + args: argparse.Namespace, + tag: str, + head_configs: List[HeadConfig], + head_config_pt: HeadConfig, ) -> Dict[str, Any]: try: checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" @@ -169,13 +162,8 @@ def assemble_mp_data( config_type_weights=None, test_path=None, seed=args.seed, - energy_key="energy", - forces_key="forces", - stress_key="stress", + key_specification=head_config_pt.key_specification, head_name="pt_head", - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) return collections_mp diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index d20e942b..9f9fe4c2 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -21,6 +21,7 @@ from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules, tools +from mace.data import KeySpecification from mace.tools import evaluate from mace.tools.train import SWAContainer @@ -32,80 +33,71 @@ class SubsetCollection: tests: List[Tuple[str, data.Configurations]] +def log_dataset_contents( + dataset: data.Configurations, dataset_name: str, keyspec: KeySpecification +) -> None: + all_property_names = list(keyspec.info_keys.keys()) + list( + keyspec.arrays_keys.keys() + ) + log_string = f"{dataset_name} [" + for prop_name in all_property_names: + if prop_name == "dipole": + log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " + else: + log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " + log_string = log_string[:-2] + "]" + logging.info(log_string) + + def get_dataset_from_xyz( work_dir: str, train_path: str, valid_path: Optional[str], valid_fraction: float, config_type_weights: Dict, + key_specification: KeySpecification, test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, head_name: str = "Default", - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "virials", - dipole_key: str = "dipoles", - charges_key: str = "charges", - head_key: str = "head", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( file_path=train_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=True, keep_isolated_atoms=keep_isolated_atoms, head_name=head_name, ) - logging.info( - f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" - ) + log_dataset_contents(all_train_configs, "Training set", key_specification) + if valid_path is not None: _, valid_configs = data.load_from_xyz( file_path=valid_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=False, head_name=head_name, ) - logging.info( - f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" - ) + log_dataset_contents(valid_configs, "Validation set", key_specification) train_configs = all_train_configs else: train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed, work_dir ) - logging.info( - f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + log_dataset_contents( + train_configs, "Random Split Training set", key_specification + ) + log_dataset_contents( + valid_configs, "Random Split Validation set", key_specification ) - test_configs = [] if test_path is not None: _, all_test_configs = data.load_from_xyz( file_path=test_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - dipole_key=dipole_key, - stress_key=stress_key, - virials_key=virials_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=False, head_name=head_name, ) @@ -115,9 +107,7 @@ def get_dataset_from_xyz( f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) for name, tmp_configs in test_configs: - logging.info( - f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" - ) + log_dataset_contents(tmp_configs, f"Test set {name}", key_specification) return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), diff --git a/tests/test_data.py b/tests/test_data.py index 9e0c49e6..6710ecdd 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -29,14 +29,20 @@ class TestAtomicData: [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, ) config_2 = deepcopy(config) config_2.positions = config.positions + 0.01 diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 03ea85c3..8805465b 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -27,21 +27,38 @@ config = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=molecule("H2COH").positions, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) + # Created the rotated environment rot = R.from_euler("z", 60, degrees=True).as_matrix() positions_rotated = np.array(rot @ config.positions.T).T config_rotated = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=positions_rotated, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) table = tools.AtomicNumberTable([1, 6, 8]) atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) @@ -110,10 +127,18 @@ def test_multi_reference(): config_multi = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=molecule("H2COH").positions, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, head="MP2", ) table_multi = tools.AtomicNumberTable([1, 6, 8]) diff --git a/tests/test_models.py b/tests/test_models.py index 8e8c60da..7a39b22a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -18,16 +18,24 @@ [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - charges=np.array([-2.0, 1.0, 1.0]), - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) # Created the rotated environment rot = R.from_euler("z", 60, degrees=True).as_matrix() @@ -35,16 +43,24 @@ config_rotated = data.Configuration( atomic_numbers=np.array([8, 1, 1]), positions=positions_rotated, - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - charges=np.array([-2.0, 1.0, 1.0]), - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) table = tools.AtomicNumberTable([1, 8]) atomic_energies = np.array([1.0, 3.0], dtype=float) diff --git a/tests/test_modules.py b/tests/test_modules.py index b99d7d6d..6afcccfb 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -30,15 +30,22 @@ def _config(): [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - stress=np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "stress": 1.0, + }, ) @@ -58,14 +65,20 @@ def _config1(): [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, head="DFT", ) @@ -81,14 +94,20 @@ def _config2(): [0.1, 1.1, 0.1], ] ), - forces=np.array( - [ - [0.1, -1.2, 0.1], - [1.1, 0.3, 0.1], - [0.1, 1.2, 0.4], - ] - ), - energy=-1.4, + properties={ + "forces": np.array( + [ + [0.1, -1.2, 0.1], + [1.1, 0.3, 0.1], + [0.1, 1.2, 0.4], + ] + ), + "energy": -1.4, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, head="MP2", ) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index e0258bd4..c1cb28c4 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -110,8 +110,8 @@ def test_preprocess_data(tmp_path, sample_configs): config = f["config_batch_0"]["config_0"] assert "atomic_numbers" in config assert "positions" in config - assert "energy" in config - assert "forces" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] original_energies = [ config.info["REF_energy"] @@ -134,19 +134,19 @@ def test_preprocess_data(tmp_path, sample_configs): config = batch[config_key] assert "atomic_numbers" in config assert "positions" in config - assert "energy" in config - assert "forces" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] - h5_energies.append(config["energy"][()]) - h5_forces.append(config["forces"][()]) + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) for val_file in val_files: with h5py.File(val_file, "r") as f: for _, batch in f.items(): for config_key in batch.keys(): config = batch[config_key] - h5_energies.append(config["energy"][()]) - h5_forces.append(config["forces"][()]) + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) print("Original energies", original_energies) print("H5 energies", h5_energies) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index ca196c47..153acfdf 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -377,7 +377,10 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="CCD", ) Es = [] @@ -563,12 +566,15 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - Es = [] for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) at.calc = calc Es.append(at.get_potential_energy()) @@ -680,15 +686,19 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - Es = [] for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) at.calc = calc Es.append(at.get_potential_energy()) + print("Es", Es) # from a run on 20/08/2024 on commit ref_Es = [ @@ -833,7 +843,10 @@ def test_run_train_multihead_replay_custum_finetuning( # Load and test the finetuned model calc = MACECalculator( - model_paths=tmp_path / "finetuned.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "finetuned.model", + device="cpu", + default_dtype="float64", + head="pt_head", ) Es = [] diff --git a/tests/test_run_train_allkeys.py b/tests/test_run_train_allkeys.py new file mode 100644 index 00000000..275b13bc --- /dev/null +++ b/tests/test_run_train_allkeys.py @@ -0,0 +1,439 @@ +import os +import subprocess +import sys +from copy import deepcopy +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms + +from mace.calculators.mace import MACECalculator + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "device": "cpu", + "seed": 5, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "interaction_first": "RealAgnosticResidualInteractionBlock", + "batch_size": 1, + "valid_batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "eval_interval": 2, + "num_radial_basis": 10, + "r_max": 6.0, + "default_dtype": "float64", +} + + +def configs_numbered_keys(): + np.random.seed(0) + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + + energies = list(np.random.normal(0.1, size=15)) + forces = list(np.random.normal(0.1, size=(15, 3, 3))) + + trial_configs_lists = [] + # some keys present, some not + keys_to_use = ( + ["REF_energy"] + + ["2_energy"] * 2 + + ["3_energy"] * 3 + + ["4_energy"] * 4 + + ["5_energy"] * 5 + ) + + force_keys_to_use = ( + ["REF_forces"] + + ["2_forces"] * 2 + + ["3_forces"] * 3 + + ["4_forces"] * 4 + + ["5_forces"] * 5 + ) + + for ind in range(15): + c = deepcopy(water) + c.info[keys_to_use[ind]] = energies[ind] + c.arrays[force_keys_to_use[ind]] = forces[ind] + c.positions += np.random.normal(0.1, size=(3, 3)) + trial_configs_lists.append(c) + + return trial_configs_lists + + +def trial_yamls_and_and_expected(): + yamls = {} + command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"} + + yamls["no_heads"] = {} + + yamls["one_head_no_dicts"] = { + "heads": { + "Default": { + "energy_key": "3_energy", + } + } + } + + yamls["one_head_with_dicts"] = { + "heads": { + "Default": { + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + } + } + } + + yamls["two_heads_no_dicts"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "energy_key": "3_energy", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + + yamls["two_heads_mixed"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + "forces_key": "4_forces", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + all_arg_sets = { + "with_command_line": { + key: {**command_line_kwargs, **value} for key, value in yamls.items() + }, + "without_command_line": yamls, + } + + all_expected_outputs = { + "with_command_line": { + "no_heads": [ + 1.0037831178668188, + 1.0183291323603265, + 1.0120784084221528, + 0.9935695881012243, + 1.0021641561865526, + 0.9999135609205868, + 0.9809440616323108, + 1.0025784765050076, + 1.0017901145495376, + 1.0136913185404515, + 1.006798563238269, + 1.0187758397828384, + 1.0180201540775071, + 1.0132368725061702, + 0.9998734173248169, + ], + "one_head_no_dicts": [ + 1.0028437510688613, + 1.0514693378041775, + 1.059933403321331, + 1.034719940573569, + 1.0438040675561824, + 1.019719477728329, + 0.9841759692947915, + 1.0435266573857496, + 1.0339501989779065, + 1.0501795448530264, + 1.0402594216704781, + 1.0604998765679152, + 1.0633411200246015, + 1.0539071190201297, + 1.0393496428177804, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9836377578288774, + 1.0196844186291318, + 1.0151628222871238, + 0.957307281711648, + 0.985574141310865, + 0.9629670134047853, + 0.9242583185138095, + 0.9807770070311039, + 0.9973679440479541, + 1.0221127246963275, + 1.0031807967874216, + 1.0358701219543687, + 1.0434208761164758, + 1.0235606028124515, + 0.9797494630655053, + ], + "two_heads_mixed": [ + 0.8664108574741868, + 0.9907166576278023, + 1.0051969372365164, + 0.978702477000018, + 1.025500166764692, + 0.9940095566375018, + 0.9034029726954119, + 1.0391739502744488, + 0.9717327061183668, + 0.972292103670355, + 1.0012510461663253, + 0.9978051155885286, + 1.0378611651753475, + 1.0003207628186224, + 1.0209509292189651, + ], + }, + "without_command_line": { + "no_heads": [ + 0.9352605307451007, + 0.991084559389268, + 0.9940350095024881, + 0.9953849198103668, + 0.9954705498032904, + 0.9964815693808411, + 0.9663142667436776, + 0.9947223808739147, + 0.9897776682803257, + 0.989027769690667, + 0.9910280920241263, + 0.992067980667518, + 0.9917276132506404, + 0.9902848752169671, + 0.9928585982942544, + ], + "one_head_no_dicts": [ + 0.9425342207393083, + 1.0149788456087416, + 1.0249228965652788, + 1.0247924743285792, + 1.02732103964481, + 1.0168852937950326, + 0.9771283495170653, + 1.0261776335561517, + 1.0130461033368028, + 1.0162619153561783, + 1.019995179866916, + 1.0209512298344965, + 1.0219971755636952, + 1.0195791901659124, + 1.0234662527729408, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9933763730233168, + 0.9986480398559268, + 1.0042486164355315, + 1.0025568793877726, + 1.0032598081704625, + 0.9926714183717912, + 0.9920385249670881, + 1.0020278841030676, + 1.0012474150830537, + 1.0039289677261019, + 1.0022718878661814, + 1.003586385624809, + 1.003436450009097, + 1.003805673887942, + 1.001450261102316, + ], + "two_heads_mixed": [ + 0.8781767864616707, + 0.9843563603794138, + 1.0145197579049248, + 0.9835060778675391, + 1.0419060462994596, + 0.9917393978520056, + 0.9091521032773944, + 1.0605463095070453, + 0.9685381713826684, + 0.9866493058823766, + 1.00305061187164, + 1.0051273128414386, + 1.037964258398104, + 1.0106663924241408, + 1.0274351814133602, + ], + }, + } + + list_of_all = [] + for key, value in all_arg_sets.items(): + for key2, value2 in value.items(): + list_of_all.append( + (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) + ) + + return list_of_all + + +def dict_to_yaml_str(data, indent=0): + yaml_str = "" + for key, value in data.items(): + yaml_str += " " * indent + str(key) + ":" + if isinstance(value, dict): + yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) + else: + yaml_str += " " + str(value) + "\n" + return yaml_str + + +_trial_yamls_and_and_expected = trial_yamls_and_and_expected() + + +@pytest.mark.parametrize( + "yaml_contents, name, expected_value", _trial_yamls_and_and_expected +) +def test_key_specification_methods( + tmp_path, yaml_contents, name, expected_value +): + fitting_configs = configs_numbered_keys() + + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) + ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = "fit_multihead_dft.xyz" + mace_params["E0s"] = "{1:0.0,8:1.0}" + mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" + del mace_params["valid_fraction"] + mace_params["max_num_epochs"] = 1 # many tests to do + del mace_params["energy_key"] + del mace_params["forces_key"] + del mace_params["stress_key"] + + mace_params["name"] = "MACE_" + + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(dict_to_yaml_str(yaml_contents)) + if len(yaml_contents) > 0: + mace_params["config"] = str(tmp_path / "config.yaml") + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) + assert p.returncode == 0 + + if "heads" in yaml_contents: + headname = list(yaml_contents["heads"].keys())[0] + else: + headname = "Default" + + calc = MACECalculator( + tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print(name) + print("Es", Es) + + assert np.allclose(np.asarray(Es), expected_value) + + +# for creating values +def make_output(): + outputs = {} + for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: + if name[0] not in outputs: + outputs[name[0]] = {} + expected = test_key_specification_methods( + Path("."), yaml_contents, name, expected_value, debug_test=False + ) + outputs[name[0]][name[1]] = expected + print(outputs)