From d641a4b0c3f42c4d87ebbfddb49eb34ff41c21df Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 10:24:48 +0100 Subject: [PATCH 01/27] WIP keyspec implementation --- mace/cli/fine_tuning_select.py | 1 + mace/cli/run_train.py | 31 ++---- mace/data/__init__.py | 4 + mace/data/atomic_data.py | 65 +++++------ mace/data/hdf5_dataset.py | 21 ++-- mace/data/utils.py | 194 ++++++++++++--------------------- mace/tools/arg_parser.py | 6 + mace/tools/multihead_tools.py | 35 ++---- mace/tools/scripts_utils.py | 51 +++------ 9 files changed, 166 insertions(+), 242 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index f3b7462f..86511c6b 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -228,6 +228,7 @@ def select_samples( if args.descriptors is not None: logging.info("Loading descriptors") descriptors = np.load(args.descriptors, allow_pickle=True) + print(args.configs_pt) atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 59e76be3..b093fe2f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -54,6 +54,7 @@ ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable +from mace.data import get_keyspec_from_args def main() -> None: @@ -71,6 +72,9 @@ def run(args: argparse.Namespace) -> None: tag = tools.get_tag(name=args.name, seed=args.seed) args, input_log_messages = tools.check_args(args) + # set keyspec in args + args.key_specification = get_keyspec_from_args(args) + if args.device == "xpu": try: import intel_extension_for_pytorch as ipex @@ -152,7 +156,7 @@ def run(args: argparse.Namespace) -> None: args.multiheads_finetuning = False if args.heads is not None: - args.heads = ast.literal_eval(args.heads) + args.heads = ast.literal_eval(args.heads) # strings from command line else: args.heads = prepare_default_head(args) @@ -187,7 +191,6 @@ def run(args: argparse.Namespace) -> None: head_config.atomic_energies_dict = ast.literal_eval( statistics["atomic_energies"] ) - # Data preparation if head_config.train_file.endswith(".xyz"): if head_config.valid_file is not None: @@ -205,12 +208,7 @@ def run(args: argparse.Namespace) -> None: config_type_weights=config_type_weights, test_path=head_config.test_file, seed=args.seed, - energy_key=head_config.energy_key, - forces_key=head_config.forces_key, - stress_key=head_config.stress_key, - virials_key=head_config.virials_key, - dipole_key=head_config.dipole_key, - charges_key=head_config.charges_key, + key_specification=args.key_specification, head_name=head_config.head_name, keep_isolated_atoms=head_config.keep_isolated_atoms, ) @@ -269,12 +267,7 @@ def run(args: argparse.Namespace) -> None: config_type_weights=None, test_path=None, seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, head_name="pt_head", keep_isolated_atoms=args.keep_isolated_atoms, ) @@ -286,12 +279,7 @@ def run(args: argparse.Namespace) -> None: statistics_file=args.statistics_file, valid_fraction=args.valid_fraction, config_type_weights=None, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, keep_isolated_atoms=args.keep_isolated_atoms, collections=collections, avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, @@ -302,6 +290,9 @@ def run(args: argparse.Namespace) -> None: f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" ) + #print(args.heads) + print(head_configs) + # Atomic number table # yapf: disable for head_config in head_configs: diff --git a/mace/data/__init__.py b/mace/data/__init__.py index c10a3698..fc44416c 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -13,6 +13,8 @@ save_configurations_as_HDF5, save_dataset_as_HDF5, test_config_types, + KeySpecification, + get_keyspec_from_args, ) __all__ = [ @@ -31,4 +33,6 @@ "dataset_from_sharded_hdf5", "save_AtomicData_to_HDF5", "save_configurations_as_HDF5", + "KeySpecification", + "get_keyspec_from_args", ] diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 814a23e0..fc43cdbb 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -116,6 +116,7 @@ def from_config( z_table: AtomicNumberTable, cutoff: float, heads: Optional[list] = None, + **kwargs, ) -> "AtomicData": if heads is None: heads = ["default"] @@ -140,69 +141,71 @@ def from_config( ).view(3, 3) ) + num_atoms = len(config.atomic_numbers) + weight = ( torch.tensor(config.weight, dtype=torch.get_default_dtype()) if config.weight is not None - else 1 + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) energy_weight = ( - torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) - if config.energy_weight is not None - else 1 + torch.tensor(config.property_weights.get("energy"), dtype=torch.get_default_dtype()) + if config.property_weights.get("energy") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces_weight = ( - torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) - if config.forces_weight is not None - else 1 + torch.tensor(config.property_weights.get("forces"), dtype=torch.get_default_dtype()) + if config.property_weights.get("forces") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) stress_weight = ( - torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) - if config.stress_weight is not None - else 1 + torch.tensor(config.property_weights.get("stress"), dtype=torch.get_default_dtype()) + if config.property_weights.get("stress") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) virials_weight = ( - torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) - if config.virials_weight is not None - else 1 + torch.tensor(config.property_weights.get("virials"), dtype=torch.get_default_dtype()) + if config.property_weights.get("virials") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces = ( - torch.tensor(config.forces, dtype=torch.get_default_dtype()) - if config.forces is not None - else None + torch.tensor(config.properties.get("forces"), dtype=torch.get_default_dtype()) + if config.properties.get("forces") is not None + else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) ) energy = ( - torch.tensor(config.energy, dtype=torch.get_default_dtype()) - if config.energy is not None - else None + torch.tensor(config.properties.get("energy"), dtype=torch.get_default_dtype()) + if config.properties.get("energy") is not None + else torch.tensor(0.0, dtype=torch.get_default_dtype()) ) stress = ( voigt_to_matrix( - torch.tensor(config.stress, dtype=torch.get_default_dtype()) + torch.tensor(config.properties.get("stress"), dtype=torch.get_default_dtype()) ).unsqueeze(0) - if config.stress is not None - else None + if config.properties.get("stress") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) virials = ( voigt_to_matrix( - torch.tensor(config.virials, dtype=torch.get_default_dtype()) + torch.tensor(config.properties.get("virials"), dtype=torch.get_default_dtype()) ).unsqueeze(0) - if config.virials is not None - else None + if config.properties.get("virials") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) dipole = ( - torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) - if config.dipole is not None - else None + torch.tensor(config.properties.get("dipole"), dtype=torch.get_default_dtype()).unsqueeze(0) + if config.properties.get("dipole") is not None + else torch.zeros(1, 3, dtype=torch.get_default_dtype()) ) charges = ( - torch.tensor(config.charges, dtype=torch.get_default_dtype()) - if config.charges is not None - else None + torch.tensor(config.properties.get("charges"), dtype=torch.get_default_dtype()) + if config.properties.get("charges") is not None + else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) ) return cls( diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 477ccd3f..acc868ad 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -48,20 +48,20 @@ def __getitem__(self, index): config_index = index % self.batch_size grp = self.file["config_batch_" + str(batch_index)] subgrp = grp["config_" + str(config_index)] + + properties = {} + property_weights = {} + for key in subgrp["properties"]: + properties[key] = unpack_value(subgrp["properties"][key][()]) + for key in subgrp["property_weights"]: + property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) + config = Configuration( atomic_numbers=subgrp["atomic_numbers"][()], positions=subgrp["positions"][()], - energy=unpack_value(subgrp["energy"][()]), - forces=unpack_value(subgrp["forces"][()]), - stress=unpack_value(subgrp["stress"][()]), - virials=unpack_value(subgrp["virials"][()]), - dipole=unpack_value(subgrp["dipole"][()]), - charges=unpack_value(subgrp["charges"][()]), + properties=properties, weight=unpack_value(subgrp["weight"][()]), - energy_weight=unpack_value(subgrp["energy_weight"][()]), - forces_weight=unpack_value(subgrp["forces_weight"][()]), - stress_weight=unpack_value(subgrp["stress_weight"][()]), - virials_weight=unpack_value(subgrp["virials_weight"][()]), + property_weights=property_weights, config_type=unpack_value(subgrp["config_type"][()]), pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), @@ -73,6 +73,7 @@ def __getitem__(self, index): z_table=self.z_table, cutoff=self.r_max, heads=self.kwargs.get("heads", ["Default"]), + **self.kwargs, ) return atomic_data diff --git a/mace/data/utils.py b/mace/data/utils.py index bb8e5448..cb781340 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -15,12 +15,7 @@ from mace.tools import AtomicNumberTable -Vector = np.ndarray # [3,] Positions = np.ndarray # [..., 3] -Forces = np.ndarray # [..., 3] -Stress = np.ndarray # [6, ], [3,3], [9, ] -Virials = np.ndarray # [6, ], [3,3], [9, ] -Charges = np.ndarray # [..., 1] Cell = np.ndarray # [3,3] Pbc = tuple # (3,) @@ -28,24 +23,49 @@ DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} +@dataclass +class KeySpecification: + info_keys: Dict[str, str] + arrays_keys: Dict[str, str] + + def update( + self, + info_keys: Optional[Dict[str, str]] = None, + arrays_keys: Optional[Dict[str, str]] = None, + ): + if info_keys is not None: + self.info_keys.update(info_keys) + if arrays_keys is not None: + self.arrays_keys.update(arrays_keys) + return self + + +def get_keyspec_from_args(args) -> KeySpecification: + """ info/array beheviour of new keys is set here """ + info_keys = { + "energy": args.energy_key, + "stress": args.stress_key, + "virials": args.virials_key, + "dipole": args.dipole_key, + "head": args.head_key, + } + arrays_keys = { + "forces": args.forces_key, + "charges": args.charges_key, + } + return KeySpecification(info_keys=info_keys, arrays_keys=arrays_keys) + + @dataclass class Configuration: atomic_numbers: np.ndarray positions: Positions # Angstrom - energy: Optional[float] = None # eV - forces: Optional[Forces] = None # eV/Angstrom - stress: Optional[Stress] = None # eV/Angstrom^3 - virials: Optional[Virials] = None # eV - dipole: Optional[Vector] = None # Debye - charges: Optional[Charges] = None # atomic unit + properties: Dict[str, np.ndarray] + property_weights: Dict[str, float] cell: Optional[Cell] = None pbc: Optional[Pbc] = None weight: float = 1.0 # weight of config in loss - energy_weight: float = 1.0 # weight of config energy in loss - forces_weight: float = 1.0 # weight of config forces in loss - stress_weight: float = 1.0 # weight of config stress in loss - virials_weight: float = 1.0 # weight of config virial in loss config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config head: Optional[str] = "Default" # head used to compute the config @@ -86,13 +106,7 @@ def random_train_valid_split( def config_from_atoms_list( atoms_list: List[ase.Atoms], - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - head_key="head", + key_specification: KeySpecification, config_type_weights: Optional[Dict[str, float]] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -104,13 +118,7 @@ def config_from_atoms_list( all_configs.append( config_from_atoms( atoms, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, config_type_weights=config_type_weights, ) ) @@ -119,26 +127,13 @@ def config_from_atoms_list( def config_from_atoms( atoms: ase.Atoms, - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - head_key="head", + key_specification: KeySpecification, config_type_weights: Optional[Dict[str, float]] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" if config_type_weights is None: config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - energy = atoms.info.get(energy_key, None) # eV - forces = atoms.arrays.get(forces_key, None) # eV / Ang - stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 - virials = atoms.info.get(virials_key, None) - dipole = atoms.info.get(dipole_key, None) # Debye - # Charges default to 0 instead of None if not found - charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit + atomic_numbers = np.array( [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] ) @@ -148,45 +143,31 @@ def config_from_atoms( weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( config_type, 1.0 ) - energy_weight = atoms.info.get("config_energy_weight", 1.0) - forces_weight = atoms.info.get("config_forces_weight", 1.0) - stress_weight = atoms.info.get("config_stress_weight", 1.0) - virials_weight = atoms.info.get("config_virials_weight", 1.0) - - head = atoms.info.get(head_key, "Default") - - # fill in missing quantities but set their weight to 0.0 - if energy is None: - energy = 0.0 - energy_weight = 0.0 - if forces is None: - forces = np.zeros(np.shape(atoms.positions)) - forces_weight = 0.0 - if stress is None: - stress = np.zeros(6) - stress_weight = 0.0 - if virials is None: - virials = np.zeros((3, 3)) - virials_weight = 0.0 - if dipole is None: - dipole = np.zeros(3) - # dipoles_weight = 0.0 + head = atoms.info.get(key_specification.info_keys["head"], "Default") + properties = {} + property_weights = {} + for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): + property_weights[name] = atoms.info.get("config_" + name + "_weight", 1.0) + for name, atoms_key in key_specification.info_keys.items(): + properties[name] = atoms.info.get(atoms_key, None) + if not atoms_key in atoms.info: + property_weights[name] = 0.0 + for name, atoms_key in key_specification.arrays_keys.items(): + properties[name] = atoms.arrays.get(atoms_key, None) + if not atoms_key in atoms.arrays: + property_weights[name] = 0.0 + + + del properties["head"] + del property_weights["head"] return Configuration( atomic_numbers=atomic_numbers, positions=atoms.get_positions(), - energy=energy, - forces=forces, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, + properties=properties, weight=weight, + property_weights=property_weights, head=head, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, config_type=config_type, pbc=pbc, cell=cell, @@ -213,18 +194,15 @@ def test_config_types( def load_from_xyz( file_path: str, config_type_weights: Dict, - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "REF_virials", - dipole_key: str = "REF_dipole", - charges_key: str = "REF_charges", - head_key: str = "head", + key_specification: KeySpecification, head_name: str = "Default", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") + energy_key = key_specification.info_keys["energy"] + forces_key = key_specification.arrays_keys["forces"] + stress_key = key_specification.info_keys["stress"] if energy_key == "energy": logging.warning( "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." @@ -265,7 +243,7 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name + atoms.info[key_specification.info_keys["head"]] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) @@ -291,13 +269,7 @@ def load_from_xyz( configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, ) return atomic_energies_dict, configs @@ -335,26 +307,7 @@ def compute_average_E0s( def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: with h5py.File(out_name, "w") as f: for i, data in enumerate(dataset): - grp = f.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - grp["head"] = data.head + save_AtomicData_to_HDF5(data, i, f) def save_AtomicData_to_HDF5(data, i, h5_file) -> None: @@ -387,20 +340,15 @@ def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> N subgroup = grp.create_group(subgroup_name) subgroup["atomic_numbers"] = write_value(config.atomic_numbers) subgroup["positions"] = write_value(config.positions) - subgroup["energy"] = write_value(config.energy) - subgroup["forces"] = write_value(config.forces) - subgroup["stress"] = write_value(config.stress) - subgroup["virials"] = write_value(config.virials) - subgroup["head"] = write_value(config.head) - subgroup["dipole"] = write_value(config.dipole) - subgroup["charges"] = write_value(config.charges) + properties_subgrp = subgroup.create_group("properties") + for key, value in config.properties.items(): + properties_subgrp[key] = write_value(value) subgroup["cell"] = write_value(config.cell) subgroup["pbc"] = write_value(config.pbc) subgroup["weight"] = write_value(config.weight) - subgroup["energy_weight"] = write_value(config.energy_weight) - subgroup["forces_weight"] = write_value(config.forces_weight) - subgroup["stress_weight"] = write_value(config.stress_weight) - subgroup["virials_weight"] = write_value(config.virials_weight) + weights_subgrp = subgroup.create_group("property_weights") + for key, value in config.property_weights.items(): + weights_subgrp[key] = write_value(value) subgroup["config_type"] = write_value(config.config_type) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index fd438990..50f0ee1f 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -405,6 +405,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str, default="REF_dipole", ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default="head", + ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 8892fa6d..468ae203 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -13,7 +13,7 @@ dict_to_namespace, get_dataset_from_xyz, ) - +from mace.data import get_keyspec_from_args, KeySpecification @dataclasses.dataclass class HeadConfig: @@ -26,12 +26,7 @@ class HeadConfig: statistics_file: Optional[str] = None valid_fraction: Optional[float] = None config_type_weights: Optional[Dict[str, float]] = None - energy_key: Optional[str] = None - forces_key: Optional[str] = None - stress_key: Optional[str] = None - virials_key: Optional[str] = None - dipole_key: Optional[str] = None - charges_key: Optional[str] = None + key_specification: Optional[KeySpecification] = None keep_isolated_atoms: Optional[bool] = None atomic_numbers: Optional[Union[List[int], List[str]]] = None mean: Optional[float] = None @@ -65,12 +60,7 @@ def dict_head_to_dataclass( mean=head.get("mean", args.mean), std=head.get("std", args.std), avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - energy_key=head.get("energy_key", args.energy_key), - forces_key=head.get("forces_key", args.forces_key), - stress_key=head.get("stress_key", args.stress_key), - virials_key=head.get("virials_key", args.virials_key), - dipole_key=head.get("dipole_key", args.dipole_key), - charges_key=head.get("charges_key", args.charges_key), + key_specification=head.get("key_specification", args.key_specification), keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) @@ -86,12 +76,7 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: "statistics_file": args.statistics_file, "valid_fraction": args.valid_fraction, "config_type_weights": args.config_type_weights, - "energy_key": args.energy_key, - "forces_key": args.forces_key, - "stress_key": args.stress_key, - "virials_key": args.virials_key, - "dipole_key": args.dipole_key, - "charges_key": args.charges_key, + "key_specification": args.key_specification, "keep_isolated_atoms": args.keep_isolated_atoms, } } @@ -161,6 +146,11 @@ def assemble_mp_data( "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) + mp_keyspec = get_keyspec_from_args(args) + mp_keyspec.update( + info_keys={"energy":"energy", "stress":"stress"}, + arrays_keys={"forces":"forces"}, + ) collections_mp, _ = get_dataset_from_xyz( work_dir=args.work_dir, train_path=f"mp_finetuning-{tag}.xyz", @@ -169,13 +159,8 @@ def assemble_mp_data( config_type_weights=None, test_path=None, seed=args.seed, - energy_key="energy", - forces_key="forces", - stress_key="stress", + key_specification=mp_keyspec, head_name="pt_head", - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) return collections_mp diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index f44390a6..54b29e81 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -23,6 +23,8 @@ from mace.tools import evaluate from mace.tools.train import SWAContainer +from mace.data import KeySpecification + @dataclasses.dataclass class SubsetCollection: @@ -37,74 +39,55 @@ def get_dataset_from_xyz( valid_path: Optional[str], valid_fraction: float, config_type_weights: Dict, + key_specification: KeySpecification, test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, head_name: str = "Default", - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "virials", - dipole_key: str = "dipoles", - charges_key: str = "charges", - head_key: str = "head", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( file_path=train_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=True, keep_isolated_atoms=keep_isolated_atoms, head_name=head_name, ) + num_energies = int(np.sum([config.property_weights["energy"] for config in all_train_configs])) + num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in all_train_configs])) logging.info( - f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" + f"Training set [{len(all_train_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{train_path}'" ) if valid_path is not None: _, valid_configs = data.load_from_xyz( file_path=valid_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=False, head_name=head_name, ) + num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) + num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in valid_configs])) logging.info( - f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" + f"Training set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" ) train_configs = all_train_configs else: train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed, work_dir ) + num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) + num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in valid_configs])) logging.info( - f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" + f"Validation set contains {len(valid_configs)} configs, [{num_energies} energy, {num_forces} forces]" ) - test_configs = [] if test_path is not None: _, all_test_configs = data.load_from_xyz( file_path=test_path, config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - dipole_key=dipole_key, - stress_key=stress_key, - virials_key=virials_key, - charges_key=charges_key, - head_key=head_key, + key_specification=key_specification, extract_atomic_energies=False, head_name=head_name, ) @@ -114,8 +97,10 @@ def get_dataset_from_xyz( f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) for name, tmp_configs in test_configs: + num_energies = int(np.sum([config.property_weights["energy"] for config in tmp_configs])) + num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in tmp_configs])) logging.info( - f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" + f"{name}: {len(tmp_configs)} configs, {num_energies} energy, {num_forces} forces" ) return ( From d61f9171d54ccd9314c3d3197c755581d33c6e27 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 14:54:23 +0100 Subject: [PATCH 02/27] WIP2 --- mace/calculators/mace.py | 3 ++- mace/cli/fine_tuning_select.py | 1 - mace/cli/run_train.py | 12 ++++----- mace/data/__init__.py | 4 +-- mace/data/hdf5_dataset.py | 5 ++-- mace/data/utils.py | 49 ++++++++++++++++++---------------- mace/tools/multihead_tools.py | 13 ++++++--- mace/tools/scripts_utils.py | 8 +++--- 8 files changed, 52 insertions(+), 43 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 292b114b..c4459788 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -198,7 +198,8 @@ def _create_result_tensors( return dict_of_tensors def _atoms_to_batch(self, atoms): - config = data.config_from_atoms(atoms, charges_key=self.charges_key) + keyspec = data.KeySpecification(info_keys={}, arrays_keys={"charges": self.charges_key}) + config = data.config_from_atoms(atoms, key_specification=keyspec) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 86511c6b..f3b7462f 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -228,7 +228,6 @@ def select_samples( if args.descriptors is not None: logging.info("Loading descriptors") descriptors = np.load(args.descriptors, allow_pickle=True) - print(args.configs_pt) atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b093fe2f..e88eb293 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -54,7 +54,7 @@ ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable -from mace.data import get_keyspec_from_args +from mace.data import update_keyspec_from_kwargs, KeySpecification def main() -> None: @@ -72,8 +72,9 @@ def run(args: argparse.Namespace) -> None: tag = tools.get_tag(name=args.name, seed=args.seed) args, input_log_messages = tools.check_args(args) - # set keyspec in args - args.key_specification = get_keyspec_from_args(args) + # default keyspec to update using heads dictionary + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) if args.device == "xpu": try: @@ -290,9 +291,6 @@ def run(args: argparse.Namespace) -> None: f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" ) - #print(args.heads) - print(head_configs) - # Atomic number table # yapf: disable for head_config in head_configs: @@ -362,6 +360,8 @@ def run(args: argparse.Namespace) -> None: for z in z_table.zs } + print('hey') + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True diff --git a/mace/data/__init__.py b/mace/data/__init__.py index fc44416c..52c56c4d 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -14,7 +14,7 @@ save_dataset_as_HDF5, test_config_types, KeySpecification, - get_keyspec_from_args, + update_keyspec_from_kwargs, ) __all__ = [ @@ -34,5 +34,5 @@ "save_AtomicData_to_HDF5", "save_configurations_as_HDF5", "KeySpecification", - "get_keyspec_from_args", + "update_keyspec_from_kwargs", ] diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index acc868ad..2ddbaf7e 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -10,7 +10,7 @@ class HDF5Dataset(Dataset): - def __init__(self, file_path, r_max, z_table, **kwargs): + def __init__(self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs): super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments self.file_path = file_path self._file = None @@ -19,6 +19,7 @@ def __init__(self, file_path, r_max, z_table, **kwargs): self.length = len(self.file.keys()) * self.batch_size self.r_max = r_max self.z_table = z_table + self.atomic_dataclass = atomic_dataclass try: self.drop_last = bool(self.file.attrs["drop_last"]) except KeyError: @@ -68,7 +69,7 @@ def __getitem__(self, index): ) if config.head is None: config.head = self.kwargs.get("head") - atomic_data = AtomicData.from_config( + atomic_data = self.atomic_dataclass.from_config( config, z_table=self.z_table, cutoff=self.r_max, diff --git a/mace/data/utils.py b/mace/data/utils.py index cb781340..9fb38804 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -5,8 +5,8 @@ ########################################################################################### import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Sequence, Tuple, Any import ase.data import ase.io @@ -25,8 +25,8 @@ @dataclass class KeySpecification: - info_keys: Dict[str, str] - arrays_keys: Dict[str, str] + info_keys: Dict[str, str] = field(default_factory=dict) + arrays_keys: Dict[str, str] = field(default_factory=dict) def update( self, @@ -40,27 +40,27 @@ def update( return self -def get_keyspec_from_args(args) -> KeySpecification: - """ info/array beheviour of new keys is set here """ - info_keys = { - "energy": args.energy_key, - "stress": args.stress_key, - "virials": args.virials_key, - "dipole": args.dipole_key, - "head": args.head_key, - } - arrays_keys = { - "forces": args.forces_key, - "charges": args.charges_key, - } - return KeySpecification(info_keys=info_keys, arrays_keys=arrays_keys) +def update_keyspec_from_kwargs(keyspec, keydict) -> KeySpecification: + # convert command line style property_key arguments into a keyspec + infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] + arrays = ["forces_key", "charges_key"] + info_keys = {} + arrays_keys = {} + for key in infos: + if key in keydict: + info_keys[key[:-4]] = keydict[key] + for key in arrays: + if key in keydict: + arrays_keys[key[:-4]] = keydict[key] + keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) + return keyspec @dataclass class Configuration: atomic_numbers: np.ndarray positions: Positions # Angstrom - properties: Dict[str, np.ndarray] + properties: Dict[str, Any] property_weights: Dict[str, float] cell: Optional[Cell] = None pbc: Optional[Pbc] = None @@ -143,23 +143,26 @@ def config_from_atoms( weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( config_type, 1.0 ) - head = atoms.info.get(key_specification.info_keys["head"], "Default") + head_key = key_specification.info_keys.get("head", "default") + head = atoms.info.get(head_key, "Default") properties = {} property_weights = {} for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): property_weights[name] = atoms.info.get("config_" + name + "_weight", 1.0) + for name, atoms_key in key_specification.info_keys.items(): properties[name] = atoms.info.get(atoms_key, None) if not atoms_key in atoms.info: property_weights[name] = 0.0 + for name, atoms_key in key_specification.arrays_keys.items(): properties[name] = atoms.arrays.get(atoms_key, None) if not atoms_key in atoms.arrays: property_weights[name] = 0.0 - - del properties["head"] - del property_weights["head"] + if "head" in properties: + del properties["head"] + del property_weights["head"] return Configuration( atomic_numbers=atomic_numbers, diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 468ae203..02120809 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -13,7 +13,8 @@ dict_to_namespace, get_dataset_from_xyz, ) -from mace.data import get_keyspec_from_args, KeySpecification +from mace.data import update_keyspec_from_kwargs, KeySpecification +from copy import deepcopy @dataclasses.dataclass class HeadConfig: @@ -42,6 +43,10 @@ class HeadConfig: def dict_head_to_dataclass( head: Dict[str, Any], head_name: str, args: argparse.Namespace ) -> HeadConfig: + # priority is global args < head property_key values < head info_keys+arrays_keys + head_keyspec = deepcopy(args.key_specification) + update_keyspec_from_kwargs(head_keyspec, head) + head_keyspec.update(info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {})) return HeadConfig( head_name=head_name, @@ -60,7 +65,7 @@ def dict_head_to_dataclass( mean=head.get("mean", args.mean), std=head.get("std", args.std), avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - key_specification=head.get("key_specification", args.key_specification), + key_specification=head_keyspec, keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) @@ -76,7 +81,6 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: "statistics_file": args.statistics_file, "valid_fraction": args.valid_fraction, "config_type_weights": args.config_type_weights, - "key_specification": args.key_specification, "keep_isolated_atoms": args.keep_isolated_atoms, } } @@ -146,7 +150,8 @@ def assemble_mp_data( "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) - mp_keyspec = get_keyspec_from_args(args) + mp_keyspec = KeySpecification() + update_keyspec_from_kwargs(mp_keyspec, vars(args)) mp_keyspec.update( info_keys={"energy":"energy", "stress":"stress"}, arrays_keys={"forces":"forces"}, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 54b29e81..308d6246 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -55,7 +55,7 @@ def get_dataset_from_xyz( head_name=head_name, ) num_energies = int(np.sum([config.property_weights["energy"] for config in all_train_configs])) - num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in all_train_configs])) + num_forces = int(np.sum([config.property_weights["forces"] for config in all_train_configs])) logging.info( f"Training set [{len(all_train_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{train_path}'" ) @@ -68,7 +68,7 @@ def get_dataset_from_xyz( head_name=head_name, ) num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) - num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in valid_configs])) + num_forces = int(np.sum([config.property_weights["forces"] for config in valid_configs])) logging.info( f"Training set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" ) @@ -78,7 +78,7 @@ def get_dataset_from_xyz( all_train_configs, valid_fraction, seed, work_dir ) num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) - num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in valid_configs])) + num_forces = int(np.sum([config.property_weights["forces"] for config in valid_configs])) logging.info( f"Validation set contains {len(valid_configs)} configs, [{num_energies} energy, {num_forces} forces]" ) @@ -98,7 +98,7 @@ def get_dataset_from_xyz( ) for name, tmp_configs in test_configs: num_energies = int(np.sum([config.property_weights["energy"] for config in tmp_configs])) - num_forces = int(np.sum([config.property_weights["forces"] * config.atomic_numbers.size for config in tmp_configs])) + num_forces = int(np.sum([config.property_weights["forces"] for config in tmp_configs])) logging.info( f"{name}: {len(tmp_configs)} configs, {num_energies} energy, {num_forces} forces" ) From a7cead18dd81ca30032b99002fef8d276bc26e12 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 16:02:11 +0100 Subject: [PATCH 03/27] WIP3 --- mace/cli/run_train.py | 9 ++++++++- mace/tools/multihead_tools.py | 10 ++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 5b5f5fec..5075937f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -250,14 +250,21 @@ def run(args: argparse.Namespace) -> None: "Using foundation model for multiheads finetuning with Materials Project data" ) heads = list(dict.fromkeys(["pt_head"] + heads)) + mp_keyspec = KeySpecification() + update_keyspec_from_kwargs(mp_keyspec, vars(args)) + mp_keyspec.update( + info_keys={"energy":"energy", "stress":"stress"}, + arrays_keys={"forces":"forces"}, + ) head_config_pt = HeadConfig( head_name="pt_head", E0s="foundation", statistics_file=args.statistics_file, + key_specification=mp_keyspec, compute_avg_num_neighbors=False, avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, ) - collections = assemble_mp_data(args, tag, head_configs) + collections = assemble_mp_data(args, tag, head_configs, head_config_pt) head_config_pt.collections = collections head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" head_configs.append(head_config_pt) diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 11ab225c..5fff4add 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -87,7 +87,7 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: def assemble_mp_data( - args: argparse.Namespace, tag: str, head_configs: List[HeadConfig] + args: argparse.Namespace, tag: str, head_configs: List[HeadConfig], head_config_pt: HeadConfig ) -> Dict[str, Any]: try: checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" @@ -150,12 +150,6 @@ def assemble_mp_data( "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) - mp_keyspec = KeySpecification() - update_keyspec_from_kwargs(mp_keyspec, vars(args)) - mp_keyspec.update( - info_keys={"energy":"energy", "stress":"stress"}, - arrays_keys={"forces":"forces"}, - ) collections_mp, _ = get_dataset_from_xyz( work_dir=args.work_dir, train_path=f"mp_finetuning-{tag}.xyz", @@ -164,7 +158,7 @@ def assemble_mp_data( config_type_weights=None, test_path=None, seed=args.seed, - key_specification=mp_keyspec, + key_specification=head_config_pt.key_specification, head_name="pt_head", keep_isolated_atoms=args.keep_isolated_atoms, ) From 818bcc8821815839158af55ae1c84260d0e9adc5 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 17:26:58 +0100 Subject: [PATCH 04/27] fixed some tests --- mace/calculators/mace.py | 1 + mace/data/utils.py | 4 +-- tests/test_data.py | 22 +++++++++------ tests/test_foundations.py | 51 ++++++++++++++++++++++++++--------- tests/test_models.py | 56 +++++++++++++++++++++++++-------------- tests/test_modules.py | 26 +++++++++++------- tests/test_run_train.py | 4 +-- 7 files changed, 109 insertions(+), 55 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index c4459788..61b494a0 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -237,6 +237,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = self._clone_batch(batch_base) node_heads = batch["head"][batch["batch"]] + print(node_heads) num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ num_atoms_arange, node_heads diff --git a/mace/data/utils.py b/mace/data/utils.py index 9fb38804..933a08ae 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -127,7 +127,7 @@ def config_from_atoms_list( def config_from_atoms( atoms: ase.Atoms, - key_specification: KeySpecification, + key_specification: KeySpecification = KeySpecification(), config_type_weights: Optional[Dict[str, float]] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -143,7 +143,7 @@ def config_from_atoms( weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( config_type, 1.0 ) - head_key = key_specification.info_keys.get("head", "default") + head_key = key_specification.info_keys.get("head", "head") head = atoms.info.get(head_key, "Default") properties = {} property_weights = {} diff --git a/tests/test_data.py b/tests/test_data.py index e893f03c..70b5b610 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -29,14 +29,20 @@ class TestAtomicData: [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, ) config_2 = deepcopy(config) config_2.positions = config.positions + 0.01 diff --git a/tests/test_foundations.py b/tests/test_foundations.py index b1724629..33d338e2 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -17,21 +17,38 @@ config = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=molecule("H2COH").positions, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) + # Created the rotated environment rot = R.from_euler("z", 60, degrees=True).as_matrix() positions_rotated = np.array(rot @ config.positions.T).T config_rotated = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=positions_rotated, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) table = tools.AtomicNumberTable([1, 6, 8]) atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) @@ -100,11 +117,19 @@ def test_multi_reference(): config_multi = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=molecule("H2COH").positions, - forces=molecule("H2COH").positions, - energy=-1.5, - charges=molecule("H2COH").numbers, - dipole=np.array([-1.5, 1.5, 2.0]), - head="MP2", + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + head='MP2' ) table_multi = tools.AtomicNumberTable([1, 6, 8]) atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) diff --git a/tests/test_models.py b/tests/test_models.py index 8e8c60da..e114e308 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -18,16 +18,24 @@ [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - charges=np.array([-2.0, 1.0, 1.0]), - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges":np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) # Created the rotated environment rot = R.from_euler("z", 60, degrees=True).as_matrix() @@ -35,16 +43,24 @@ config_rotated = data.Configuration( atomic_numbers=np.array([8, 1, 1]), positions=positions_rotated, - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - charges=np.array([-2.0, 1.0, 1.0]), - dipole=np.array([-1.5, 1.5, 2.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges":np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, ) table = tools.AtomicNumberTable([1, 8]) atomic_energies = np.array([1.0, 3.0], dtype=float) diff --git a/tests/test_modules.py b/tests/test_modules.py index 5539ceb1..3097c735 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -23,16 +23,22 @@ [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - # stress if voigt 6 notation - stress=np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "stress":np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "stress": 1.0, + }, ) table = AtomicNumberTable([1, 8]) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 6c41ce0f..cd28823a 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -319,7 +319,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): mace_params["loss"] = "weighted" mace_params["hidden_irreps"] = "128x0e" mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float32" + mace_params["default_dtype"] = "float64" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["config"] = tmp_path / "config.yaml" @@ -349,7 +349,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float32" + tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] From b52d20b79e6a63a82ad5c579aa40112024103cb5 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 17:44:41 +0100 Subject: [PATCH 05/27] new interface passing old tests --- mace/calculators/mace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 61b494a0..c4459788 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -237,7 +237,6 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = self._clone_batch(batch_base) node_heads = batch["head"][batch["batch"]] - print(node_heads) num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ num_atoms_arange, node_heads From a742297850c24602eb67b0405c4cd3917c176319 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 18:52:20 +0100 Subject: [PATCH 06/27] linting and fixed preprocess data --- mace/calculators/mace.py | 4 +++- mace/cli/preprocess_data.py | 12 +++++----- mace/cli/run_train.py | 17 +++++++------- mace/data/__init__.py | 2 +- mace/data/atomic_data.py | 42 ++++++++++++++++++++++++++--------- mace/data/hdf5_dataset.py | 4 +++- mace/data/utils.py | 6 ++--- mace/tools/multihead_tools.py | 15 +++++++++---- mace/tools/scripts_utils.py | 35 ++++++++++++++++++++--------- tests/test_foundations.py | 2 +- tests/test_models.py | 4 ++-- tests/test_modules.py | 2 +- 12 files changed, 95 insertions(+), 50 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index c4459788..37b6af0c 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -198,7 +198,9 @@ def _create_result_tensors( return dict_of_tensors def _atoms_to_batch(self, atoms): - keyspec = data.KeySpecification(info_keys={}, arrays_keys={"charges": self.charges_key}) + keyspec = data.KeySpecification( + info_keys={}, arrays_keys={"charges": self.charges_key} + ) config = data.config_from_atoms(atoms, key_specification=keyspec) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index de34b1d4..68b72955 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -18,6 +18,7 @@ from mace import data, tools from mace.data.utils import save_configurations_as_HDF5 +from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.modules import compute_statistics from mace.tools import torch_geometric from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz @@ -129,6 +130,10 @@ def run(args: argparse.Namespace): new hdf5 file that is ready for training with on-the-fly dataloading """ + # currently support only command line property_key syntax + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + # Setup tools.set_seeds(args.seed) random.seed(args.seed) @@ -162,12 +167,7 @@ def run(args: argparse.Namespace): config_type_weights=config_type_weights, test_path=args.test_file, seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, + key_specification=args.key_specification, ) # Atomic number table diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 5075937f..58e0a37a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -24,6 +24,7 @@ import mace from mace import data, tools from mace.calculators.foundations_models import mace_mp, mace_off +from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools import torch_geometric from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( @@ -52,7 +53,6 @@ ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable -from mace.data import update_keyspec_from_kwargs, KeySpecification def main() -> None: @@ -71,7 +71,7 @@ def run(args: argparse.Namespace) -> None: args, input_log_messages = tools.check_args(args) # default keyspec to update using heads dictionary - args.key_specification = KeySpecification() + args.key_specification = KeySpecification() update_keyspec_from_kwargs(args.key_specification, vars(args)) if args.device == "xpu": @@ -156,12 +156,12 @@ def run(args: argparse.Namespace) -> None: args.multiheads_finetuning = False if args.heads is not None: - args.heads = ast.literal_eval(args.heads) # strings from command line + args.heads = ast.literal_eval(args.heads) else: args.heads = prepare_default_head(args) logging.info("===========LOADING INPUT DATA===========") - heads = list(args.heads.keys()) + heads = list(args.heads.keys()) # TODO: rename to heads_names logging.info(f"Using heads: {heads}") head_configs: List[HeadConfig] = [] for head, head_args in args.heads.items(): @@ -208,7 +208,7 @@ def run(args: argparse.Namespace) -> None: config_type_weights=config_type_weights, test_path=head_config.test_file, seed=args.seed, - key_specification=args.key_specification, + key_specification=head_config.key_specification, head_name=head_config.head_name, keep_isolated_atoms=head_config.keep_isolated_atoms, ) @@ -253,8 +253,8 @@ def run(args: argparse.Namespace) -> None: mp_keyspec = KeySpecification() update_keyspec_from_kwargs(mp_keyspec, vars(args)) mp_keyspec.update( - info_keys={"energy":"energy", "stress":"stress"}, - arrays_keys={"forces":"forces"}, + info_keys={"energy": "energy", "stress": "stress"}, + arrays_keys={"forces": "forces"}, ) head_config_pt = HeadConfig( head_name="pt_head", @@ -273,6 +273,7 @@ def run(args: argparse.Namespace) -> None: f"Using foundation model for multiheads finetuning with {args.pt_train_file}" ) heads = list(dict.fromkeys(["pt_head"] + heads)) + # TODO: new interface so that pretrained head has a seperate keyspec and does not rely on args collections, atomic_energies_dict = get_dataset_from_xyz( work_dir=args.work_dir, train_path=args.pt_train_file, @@ -374,8 +375,6 @@ def run(args: argparse.Namespace) -> None: for z in z_table.zs } - print('hey') - if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 52c56c4d..07999d39 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -4,6 +4,7 @@ from .utils import ( Configuration, Configurations, + KeySpecification, compute_average_E0s, config_from_atoms, config_from_atoms_list, @@ -13,7 +14,6 @@ save_configurations_as_HDF5, save_dataset_as_HDF5, test_config_types, - KeySpecification, update_keyspec_from_kwargs, ) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index fc43cdbb..1ac833c6 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -116,7 +116,7 @@ def from_config( z_table: AtomicNumberTable, cutoff: float, heads: Optional[list] = None, - **kwargs, + **kwargs, # pylint: disable=unused-argument ) -> "AtomicData": if heads is None: heads = ["default"] @@ -150,60 +150,80 @@ def from_config( ) energy_weight = ( - torch.tensor(config.property_weights.get("energy"), dtype=torch.get_default_dtype()) + torch.tensor( + config.property_weights.get("energy"), dtype=torch.get_default_dtype() + ) if config.property_weights.get("energy") is not None else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces_weight = ( - torch.tensor(config.property_weights.get("forces"), dtype=torch.get_default_dtype()) + torch.tensor( + config.property_weights.get("forces"), dtype=torch.get_default_dtype() + ) if config.property_weights.get("forces") is not None else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) stress_weight = ( - torch.tensor(config.property_weights.get("stress"), dtype=torch.get_default_dtype()) + torch.tensor( + config.property_weights.get("stress"), dtype=torch.get_default_dtype() + ) if config.property_weights.get("stress") is not None else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) virials_weight = ( - torch.tensor(config.property_weights.get("virials"), dtype=torch.get_default_dtype()) + torch.tensor( + config.property_weights.get("virials"), dtype=torch.get_default_dtype() + ) if config.property_weights.get("virials") is not None else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) forces = ( - torch.tensor(config.properties.get("forces"), dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("forces"), dtype=torch.get_default_dtype() + ) if config.properties.get("forces") is not None else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) ) energy = ( - torch.tensor(config.properties.get("energy"), dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("energy"), dtype=torch.get_default_dtype() + ) if config.properties.get("energy") is not None else torch.tensor(0.0, dtype=torch.get_default_dtype()) ) stress = ( voigt_to_matrix( - torch.tensor(config.properties.get("stress"), dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("stress"), dtype=torch.get_default_dtype() + ) ).unsqueeze(0) if config.properties.get("stress") is not None else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) virials = ( voigt_to_matrix( - torch.tensor(config.properties.get("virials"), dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("virials"), dtype=torch.get_default_dtype() + ) ).unsqueeze(0) if config.properties.get("virials") is not None else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) ) dipole = ( - torch.tensor(config.properties.get("dipole"), dtype=torch.get_default_dtype()).unsqueeze(0) + torch.tensor( + config.properties.get("dipole"), dtype=torch.get_default_dtype() + ).unsqueeze(0) if config.properties.get("dipole") is not None else torch.zeros(1, 3, dtype=torch.get_default_dtype()) ) charges = ( - torch.tensor(config.properties.get("charges"), dtype=torch.get_default_dtype()) + torch.tensor( + config.properties.get("charges"), dtype=torch.get_default_dtype() + ) if config.properties.get("charges") is not None else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) ) diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 2ddbaf7e..f7755175 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -10,7 +10,9 @@ class HDF5Dataset(Dataset): - def __init__(self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs): + def __init__( + self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs + ): super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments self.file_path = file_path self._file = None diff --git a/mace/data/utils.py b/mace/data/utils.py index 933a08ae..53155813 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -6,7 +6,7 @@ import logging from dataclasses import dataclass, field -from typing import Dict, List, Optional, Sequence, Tuple, Any +from typing import Any, Dict, List, Optional, Sequence, Tuple import ase.data import ase.io @@ -133,7 +133,7 @@ def config_from_atoms( """Convert ase.Atoms to Configuration""" if config_type_weights is None: config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - + atomic_numbers = np.array( [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] ) @@ -159,7 +159,7 @@ def config_from_atoms( properties[name] = atoms.arrays.get(atoms_key, None) if not atoms_key in atoms.arrays: property_weights[name] = 0.0 - + if "head" in properties: del properties["head"] del property_weights["head"] diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 5fff4add..fb8191e8 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -3,18 +3,19 @@ import logging import os import urllib.request +from copy import deepcopy from typing import Any, Dict, List, Optional, Union import torch from mace.cli.fine_tuning_select import select_samples +from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools.scripts_utils import ( SubsetCollection, dict_to_namespace, get_dataset_from_xyz, ) -from mace.data import update_keyspec_from_kwargs, KeySpecification -from copy import deepcopy + @dataclasses.dataclass class HeadConfig: @@ -46,7 +47,9 @@ def dict_head_to_dataclass( # priority is global args < head property_key values < head info_keys+arrays_keys head_keyspec = deepcopy(args.key_specification) update_keyspec_from_kwargs(head_keyspec, head) - head_keyspec.update(info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {})) + head_keyspec.update( + info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {}) + ) return HeadConfig( head_name=head_name, @@ -79,6 +82,7 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: "test_dir": args.test_dir, "E0s": args.E0s, "statistics_file": args.statistics_file, + "key_specification": args.key_specification, "valid_fraction": args.valid_fraction, "config_type_weights": args.config_type_weights, "keep_isolated_atoms": args.keep_isolated_atoms, @@ -87,7 +91,10 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: def assemble_mp_data( - args: argparse.Namespace, tag: str, head_configs: List[HeadConfig], head_config_pt: HeadConfig + args: argparse.Namespace, + tag: str, + head_configs: List[HeadConfig], + head_config_pt: HeadConfig, ) -> Dict[str, Any]: try: checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 308d6246..b4b6ed5d 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -20,11 +20,10 @@ from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules, tools +from mace.data import KeySpecification from mace.tools import evaluate from mace.tools.train import SWAContainer -from mace.data import KeySpecification - @dataclasses.dataclass class SubsetCollection: @@ -54,8 +53,12 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, head_name=head_name, ) - num_energies = int(np.sum([config.property_weights["energy"] for config in all_train_configs])) - num_forces = int(np.sum([config.property_weights["forces"] for config in all_train_configs])) + num_energies = int( + np.sum([config.property_weights["energy"] for config in all_train_configs]) + ) + num_forces = int( + np.sum([config.property_weights["forces"] for config in all_train_configs]) + ) logging.info( f"Training set [{len(all_train_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{train_path}'" ) @@ -67,8 +70,12 @@ def get_dataset_from_xyz( extract_atomic_energies=False, head_name=head_name, ) - num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) - num_forces = int(np.sum([config.property_weights["forces"] for config in valid_configs])) + num_energies = int( + np.sum([config.property_weights["energy"] for config in valid_configs]) + ) + num_forces = int( + np.sum([config.property_weights["forces"] for config in valid_configs]) + ) logging.info( f"Training set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" ) @@ -77,8 +84,12 @@ def get_dataset_from_xyz( train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed, work_dir ) - num_energies = int(np.sum([config.property_weights["energy"] for config in valid_configs])) - num_forces = int(np.sum([config.property_weights["forces"] for config in valid_configs])) + num_energies = int( + np.sum([config.property_weights["energy"] for config in valid_configs]) + ) + num_forces = int( + np.sum([config.property_weights["forces"] for config in valid_configs]) + ) logging.info( f"Validation set contains {len(valid_configs)} configs, [{num_energies} energy, {num_forces} forces]" ) @@ -97,8 +108,12 @@ def get_dataset_from_xyz( f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) for name, tmp_configs in test_configs: - num_energies = int(np.sum([config.property_weights["energy"] for config in tmp_configs])) - num_forces = int(np.sum([config.property_weights["forces"] for config in tmp_configs])) + num_energies = int( + np.sum([config.property_weights["energy"] for config in tmp_configs]) + ) + num_forces = int( + np.sum([config.property_weights["forces"] for config in tmp_configs]) + ) logging.info( f"{name}: {len(tmp_configs)} configs, {num_energies} energy, {num_forces} forces" ) diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 33d338e2..846d85cd 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -129,7 +129,7 @@ def test_multi_reference(): "charges": 1.0, "dipole": 1.0, }, - head='MP2' + head="MP2", ) table_multi = tools.AtomicNumberTable([1, 6, 8]) atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) diff --git a/tests/test_models.py b/tests/test_models.py index e114e308..7a39b22a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ ] ), "energy": -1.5, - "charges":np.array([-2.0, 1.0, 1.0]), + "charges": np.array([-2.0, 1.0, 1.0]), "dipole": np.array([-1.5, 1.5, 2.0]), }, property_weights={ @@ -52,7 +52,7 @@ ] ), "energy": -1.5, - "charges":np.array([-2.0, 1.0, 1.0]), + "charges": np.array([-2.0, 1.0, 1.0]), "dipole": np.array([-1.5, 1.5, 2.0]), }, property_weights={ diff --git a/tests/test_modules.py b/tests/test_modules.py index 3097c735..d5df74df 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -32,7 +32,7 @@ ] ), "energy": -1.5, - "stress":np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), }, property_weights={ "forces": 1.0, From ddfb31d7e41b61ab9318647222c672dc9fcbca12 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 2 Sep 2024 19:05:12 +0100 Subject: [PATCH 07/27] more linting --- mace/cli/preprocess_data.py | 2 +- mace/data/atomic_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 68b72955..540b39ce 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -17,8 +17,8 @@ import tqdm from mace import data, tools -from mace.data.utils import save_configurations_as_HDF5 from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.data.utils import save_configurations_as_HDF5 from mace.modules import compute_statistics from mace.tools import torch_geometric from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 1ac833c6..9b2e86f8 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -116,7 +116,7 @@ def from_config( z_table: AtomicNumberTable, cutoff: float, heads: Optional[list] = None, - **kwargs, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument ) -> "AtomicData": if heads is None: heads = ["default"] From 4253216a2bb36b752c3cac4c005a5411998507ff Mon Sep 17 00:00:00 2001 From: WillBaldwin0 <47224653+WillBaldwin0@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:37:52 +0100 Subject: [PATCH 08/27] Update unittest.yaml testing python 3.11 --- .github/workflows/unittest.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 857bf894..f75da0c5 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.11" cache: "pip" - name: Install requirements From 8803c481eb888b8f5cae1a502ce9a535c13756e2 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Tue, 3 Sep 2024 12:03:54 +0100 Subject: [PATCH 09/27] fix key overwriting and unittests --- .github/workflows/unittest.yaml | 2 +- mace/data/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index f75da0c5..857bf894 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.10" cache: "pip" - name: Install requirements diff --git a/mace/data/utils.py b/mace/data/utils.py index 53155813..9a2f73a5 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -210,7 +210,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." ) - energy_key = "REF_energy" + key_specification.info_keys["energy"] = "REF_energy" for atoms in atoms_list: try: atoms.info["REF_energy"] = atoms.get_potential_energy() @@ -221,7 +221,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." ) - forces_key = "REF_forces" + key_specification.info_keys["forces"] = "REF_forces" for atoms in atoms_list: try: atoms.arrays["REF_forces"] = atoms.get_forces() @@ -232,7 +232,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." ) - stress_key = "REF_stress" + key_specification.info_keys["stress"] = "REF_stress" for atoms in atoms_list: try: atoms.info["REF_stress"] = atoms.get_stress() From 6e3199508b94667c7377e7ba2a95ad822b19cb84 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Tue, 3 Sep 2024 12:30:28 +0100 Subject: [PATCH 10/27] small bug in settings REF_forces --- mace/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 9a2f73a5..01364840 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -221,7 +221,7 @@ def load_from_xyz( logging.warning( "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." ) - key_specification.info_keys["forces"] = "REF_forces" + key_specification.arrays_keys["forces"] = "REF_forces" for atoms in atoms_list: try: atoms.arrays["REF_forces"] = atoms.get_forces() From dfefca0ec242084be72471368324606a22a968da Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Wed, 4 Sep 2024 17:45:56 +0100 Subject: [PATCH 11/27] remove head key and some minor fixes --- mace/cli/preprocess_data.py | 1 + mace/cli/run_train.py | 3 +-- mace/data/utils.py | 13 ++++++------- mace/tools/multihead_tools.py | 4 +++- mace/tools/scripts_utils.py | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 540b39ce..46806850 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -168,6 +168,7 @@ def run(args: argparse.Namespace): test_path=args.test_file, seed=args.seed, key_specification=args.key_specification, + head_name=None, ) # Atomic number table diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 58e0a37a..2d3f97d7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -161,7 +161,7 @@ def run(args: argparse.Namespace) -> None: args.heads = prepare_default_head(args) logging.info("===========LOADING INPUT DATA===========") - heads = list(args.heads.keys()) # TODO: rename to heads_names + heads = list(args.heads.keys()) logging.info(f"Using heads: {heads}") head_configs: List[HeadConfig] = [] for head, head_args in args.heads.items(): @@ -273,7 +273,6 @@ def run(args: argparse.Namespace) -> None: f"Using foundation model for multiheads finetuning with {args.pt_train_file}" ) heads = list(dict.fromkeys(["pt_head"] + heads)) - # TODO: new interface so that pretrained head has a seperate keyspec and does not rely on args collections, atomic_energies_dict = get_dataset_from_xyz( work_dir=args.work_dir, train_path=args.pt_train_file, diff --git a/mace/data/utils.py b/mace/data/utils.py index 01364840..7085a2f0 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -108,6 +108,7 @@ def config_from_atoms_list( atoms_list: List[ase.Atoms], key_specification: KeySpecification, config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default" ) -> Configurations: """Convert list of ase.Atoms into Configurations""" if config_type_weights is None: @@ -120,6 +121,7 @@ def config_from_atoms_list( atoms, key_specification=key_specification, config_type_weights=config_type_weights, + head_name=head_name ) ) return all_configs @@ -129,6 +131,7 @@ def config_from_atoms( atoms: ase.Atoms, key_specification: KeySpecification = KeySpecification(), config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default" ) -> Configuration: """Convert ase.Atoms to Configuration""" if config_type_weights is None: @@ -143,8 +146,7 @@ def config_from_atoms( weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( config_type, 1.0 ) - head_key = key_specification.info_keys.get("head", "head") - head = atoms.info.get(head_key, "Default") + properties = {} property_weights = {} for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): @@ -160,17 +162,13 @@ def config_from_atoms( if not atoms_key in atoms.arrays: property_weights[name] = 0.0 - if "head" in properties: - del properties["head"] - del property_weights["head"] - return Configuration( atomic_numbers=atomic_numbers, positions=atoms.get_positions(), properties=properties, weight=weight, property_weights=property_weights, - head=head, + head=head_name, config_type=config_type, pbc=pbc, cell=cell, @@ -273,6 +271,7 @@ def load_from_xyz( atoms_list, config_type_weights=config_type_weights, key_specification=key_specification, + head_name=head_name, ) return atomic_energies_dict, configs diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index fb8191e8..aba12272 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -50,7 +50,9 @@ def dict_head_to_dataclass( head_keyspec.update( info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {}) ) - + # parser+head args that have no defaults but are required + if (args.train_file is None) and (head.get("train_file", None) is None) : + raise ValueError("train file is not set in the head config yaml or via command line args") return HeadConfig( head_name=head_name, train_file=head.get("train_file", args.train_file), diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index b4b6ed5d..17c6df5e 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -77,7 +77,7 @@ def get_dataset_from_xyz( np.sum([config.property_weights["forces"] for config in valid_configs]) ) logging.info( - f"Training set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" + f"Validation set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" ) train_configs = all_train_configs else: From 2e4d524935f46aa964a657f0db09cc8e1098849e Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Thu, 5 Sep 2024 11:22:45 +0100 Subject: [PATCH 12/27] default to Default for heads --- mace/calculators/mace.py | 15 +++++++++++---- mace/data/atomic_data.py | 2 +- mace/modules/models.py | 2 +- mace/tools/multihead_tools.py | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 37b6af0c..e621af7f 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -145,10 +145,17 @@ def __init__( [int(z) for z in self.models[0].atomic_numbers] ) self.charges_key = charges_key + try: - self.heads = self.models[0].heads + self.available_heads = self.models[0].heads except AttributeError: - self.heads = ["Default"] + self.available_heads = ["Default"] + self.head = kwargs.get('head', 'Default') + assert self.head in self.available_heads, f'specified head {self.head}, but model available model heads are {heads}' + + print('using head', self.head, 'out of', self.available_heads) + + model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -201,11 +208,11 @@ def _atoms_to_batch(self, atoms): keyspec = data.KeySpecification( info_keys={}, arrays_keys={"charges": self.charges_key} ) - config = data.config_from_atoms(atoms, key_specification=keyspec) + config = data.config_from_atoms(atoms, key_specification=keyspec, head_name=self.head) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads + config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads ) ], batch_size=1, diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 9b2e86f8..5217e314 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -119,7 +119,7 @@ def from_config( **kwargs, # pylint: disable=unused-argument ) -> "AtomicData": if heads is None: - heads = ["default"] + heads = ["Default"] edge_index, shifts, unit_shifts = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell ) diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab43..05578ac8 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -74,7 +74,7 @@ def __init__( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) if heads is None: - heads = ["default"] + heads = ["Default"] self.heads = heads if isinstance(correlation, int): correlation = [correlation] * num_interactions diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index aba12272..63c0cec6 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -77,7 +77,7 @@ def dict_head_to_dataclass( def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: return { - "default": { + "Default": { "train_file": args.train_file, "valid_file": args.valid_file, "test_file": args.test_file, From 1390a0b46f035d640e97288985c60f1a1975d2cc Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Fri, 6 Sep 2024 09:36:44 +0100 Subject: [PATCH 13/27] added new test and fix calculator --- mace/calculators/mace.py | 20 +- mace/data/atomic_data.py | 2 +- mace/data/utils.py | 6 +- mace/tools/multihead_tools.py | 6 +- tests/test_run_train.py | 10 +- tests/test_run_train_allkeys.py | 431 ++++++++++++++++++++++++++++++++ 6 files changed, 457 insertions(+), 18 deletions(-) create mode 100644 tests/test_run_train_allkeys.py diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index e621af7f..8c2c1f09 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -150,12 +150,13 @@ def __init__( self.available_heads = self.models[0].heads except AttributeError: self.available_heads = ["Default"] - self.head = kwargs.get('head', 'Default') - assert self.head in self.available_heads, f'specified head {self.head}, but model available model heads are {heads}' + self.head = kwargs.get("head", "Default") + assert ( + self.head in self.available_heads + ), f"specified head {self.head}, but model available model heads are {self.available_heads}" + + print("using head", self.head, "out of", self.available_heads) - print('using head', self.head, 'out of', self.available_heads) - - model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -208,11 +209,16 @@ def _atoms_to_batch(self, atoms): keyspec = data.KeySpecification( info_keys={}, arrays_keys={"charges": self.charges_key} ) - config = data.config_from_atoms(atoms, key_specification=keyspec, head_name=self.head) + config = data.config_from_atoms( + atoms, key_specification=keyspec, head_name=self.head + ) data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.available_heads, ) ], batch_size=1, diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 15734b6c..07ce39c2 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -119,8 +119,8 @@ def from_config( **kwargs, # pylint: disable=unused-argument ) -> "AtomicData": if heads is None: - edge_index, shifts, unit_shifts, cell = get_neighborhood( heads = ["Default"] + edge_index, shifts, unit_shifts, cell = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) diff --git a/mace/data/utils.py b/mace/data/utils.py index 7085a2f0..ee4678b8 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -108,7 +108,7 @@ def config_from_atoms_list( atoms_list: List[ase.Atoms], key_specification: KeySpecification, config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default" + head_name: str = "Default", ) -> Configurations: """Convert list of ase.Atoms into Configurations""" if config_type_weights is None: @@ -121,7 +121,7 @@ def config_from_atoms_list( atoms, key_specification=key_specification, config_type_weights=config_type_weights, - head_name=head_name + head_name=head_name, ) ) return all_configs @@ -131,7 +131,7 @@ def config_from_atoms( atoms: ase.Atoms, key_specification: KeySpecification = KeySpecification(), config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default" + head_name: str = "Default", ) -> Configuration: """Convert ase.Atoms to Configuration""" if config_type_weights is None: diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 63c0cec6..369d1e10 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -51,8 +51,10 @@ def dict_head_to_dataclass( info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {}) ) # parser+head args that have no defaults but are required - if (args.train_file is None) and (head.get("train_file", None) is None) : - raise ValueError("train file is not set in the head config yaml or via command line args") + if (args.train_file is None) and (head.get("train_file", None) is None): + raise ValueError( + "train file is not set in the head config yaml or via command line args" + ) return HeadConfig( head_name=head_name, train_file=head.get("train_file", args.train_file), diff --git a/tests/test_run_train.py b/tests/test_run_train.py index fe6c8c46..f751c02f 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -349,7 +349,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64" + tmp_path / "MACE.model", device="cpu", default_dtype="float64", head="CCD" ) Es = [] @@ -535,12 +535,12 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - Es = [] for at in fitting_configs: + config_head = at.info.get('head', 'MP2') + calc = MACECalculator( + tmp_path / "MACE.model", device="cpu", default_dtype="float64", head=config_head + ) at.calc = calc Es.append(at.get_potential_energy()) diff --git a/tests/test_run_train_allkeys.py b/tests/test_run_train_allkeys.py new file mode 100644 index 00000000..c3632a40 --- /dev/null +++ b/tests/test_run_train_allkeys.py @@ -0,0 +1,431 @@ +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms +from copy import deepcopy + +from mace.calculators.mace import MACECalculator +import mace +np.random.seed(0) + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, +} + + +def configs_numbered_keys(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + + energies = list(np.random.normal(0.1, size=15)) + forces = list(np.random.normal(0.1, size=(15, 3, 3))) + + trial_configs_lists = [] + # some keys present, some not + keys_to_use = ["REF_energy"] + \ + ["2_energy"]*2 + \ + ["3_energy"]*3 + \ + ["4_energy"]*4 + \ + ["5_energy"]*5 + + force_keys_to_use = ["REF_forces"] + \ + ["2_forces"]*2 + \ + ["3_forces"]*3 + \ + ["4_forces"]*4 + \ + ["5_forces"]*5 + + for ind in range(15): + c = deepcopy(water) + c.info[keys_to_use[ind]] = energies[ind] + c.arrays[force_keys_to_use[ind]] = forces[ind] + c.positions += np.random.normal(0.1, size=(3,3)) + trial_configs_lists.append(c) + + return trial_configs_lists + + +def trial_yamls_and_and_expected(): + yamls = {} + command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"} + + yamls["no_heads"] = {} + + yamls["one_head_no_dicts"] = { + "heads": { + "Default": { + "energy_key": "3_energy", + } + } + } + + yamls["one_head_with_dicts"] = { + "heads": { + "Default": { + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + } + } + } + + yamls["two_heads_no_dicts"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "energy_key": "3_energy", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + + yamls["two_heads_mixed"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + "forces_key": "4_forces", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + all_arg_sets = { + "with_command_line": { + key: {**command_line_kwargs, **value} for key, value in yamls.items() + }, + "without_command_line": {key: value for key, value in yamls.items()}, + } + + all_expected_outputs = { + "with_command_line": { + "no_heads": [ + 1.0037831178668188, + 1.0183291323603265, + 1.0120784084221528, + 0.9935695881012243, + 1.0021641561865526, + 0.9999135609205868, + 0.9809440616323108, + 1.0025784765050076, + 1.0017901145495376, + 1.0136913185404515, + 1.006798563238269, + 1.0187758397828384, + 1.0180201540775071, + 1.0132368725061702, + 0.9998734173248169, + ], + "one_head_no_dicts": [ + 1.0000166263499473, + 1.0021620416915131, + 1.0046772896383978, + 1.001465441141607, + 1.0055517812192685, + 1.0015992637882436, + 1.0020402319259156, + 1.0054369609690694, + 1.0048820789691186, + 1.004069245195459, + 1.0036930315433792, + 1.0045657994185517, + 1.0049657202069904, + 1.0054495991318766, + 1.0059574240719107, + ], + "one_head_with_dicts": [ + 0.9824761809087968, + 0.982723323954806, + 0.9804037844582393, + 0.9892979892015554, + 0.990123250174031, + 0.9872765633686582, + 0.9792985720223041, + 0.9834849185579561, + 0.9855709706241268, + 0.9838176625524332, + 0.9802380433794929, + 0.9798924747115749, + 0.9941246312362003, + 0.9843619552495816, + 1.0234402440454935, + ], + "two_heads_no_dicts": [ + 0.9533172241443488, + 0.971143149409332, + 0.9591034423596022, + 0.9259180388268078, + 0.9866672025915887, + 0.9468387512978088, + 0.972806503955744, + 0.9268821579802152, + 0.9399783569634511, + 0.9566909477955546, + 1.0280484765877604, + 0.9638781804485581, + 0.9386762390303685, + 0.9513720471682103, + 1.061099519224079, + ], + "two_heads_mixed": [ + 1.0008821794271117, + 0.9921658975489234, + 1.0128605897789047, + 1.0177680320432732, + 1.0040635968372489, + 1.0134535284156263, + 0.9900156994903402, + 0.9950077226207892, + 0.9931748657782218, + 0.9970869871816835, + 1.0036266515981311, + 0.9882332649269495, + 0.9973620987054619, + 1.0089283927259747, + 0.9984375026446699, + ], + }, + "without_command_line": { + "no_heads": [ + 0.9723249939003304, + 0.99830004939027, + 0.9976857883262907, + 1.0026915904907623, + 0.9986047122447201, + 1.0056392530400915, + 0.9955992271879338, + 0.9925618058915322, + 0.9992873743817391, + 1.0017751144824205, + 0.9965424145952742, + 0.9980104982304532, + 0.996970035434205, + 1.0017462160896793, + 1.00453025524217, + ], + "one_head_no_dicts": [ + 0.9668728328024694, + 0.9559554052674338, + 0.9558003309868804, + 0.9568681942948057, + 0.9471374531635678, + 0.9573665902279203, + 0.9509504944430629, + 0.9449430732494284, + 0.9487872001503757, + 0.9515435134805473, + 0.9616246560028083, + 0.9652201708552365, + 0.9518567860504985, + 0.9695448453855497, + 0.9595931614125687, + ], + "one_head_with_dicts": [ + 0.9904238487805224, + 0.9787489784129528, + 0.9980000798872206, + 1.0081047579760913, + 0.970990405481672, + 1.0296635919726917, + 1.0070991842774164, + 0.9977357706770508, + 0.9729041794133619, + 0.9952167479342705, + 1.0256795692987708, + 1.0005027614317226, + 1.0042896304620599, + 0.9933015438418198, + 0.9941762126172496, + ], + "two_heads_no_dicts": [ + 0.8234141049979373, + 0.8486132642907047, + 0.8761921831858267, + 0.8086446850523645, + 0.8185616207749478, + 0.8349295066652644, + 0.8695339796701849, + 0.8783625449137391, + 0.8513575832201994, + 0.8428073015147357, + 0.8514345324682252, + 0.8774982178381736, + 0.8724648944295484, + 0.9071025824523504, + 0.8671562526370659, + ], + "two_heads_mixed": [ + 1.0142275963817828, + 0.9252946269851097, + 0.9905802472120683, + 1.0104854763203601, + 1.0627569806879018, + 0.894635070244004, + 0.9570335273959514, + 0.9917699286224028, + 0.9731498108644769, + 1.02712188692559, + 1.0255958579172193, + 1.0134291318470228, + 0.9601947878290134, + 0.9593860448787849, + 1.0044099804202045, + ], + }, + } + + + list_of_all = [] + for key, value in all_arg_sets.items(): + print(key) + for key2, value2 in value.items(): + print(' ', key2) + list_of_all.append((value2, (key, key2), np.asarray(all_expected_outputs[key][key2]))) + + return list_of_all + + +def dict_to_yaml_str(data, indent=0): + yaml_str = "" + for key, value in data.items(): + yaml_str += " " * indent + str(key) + ":" + if isinstance(value, dict): + yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) + else: + yaml_str += " " + str(value) + "\n" + return yaml_str + + +_trial_yamls_and_and_expected = trial_yamls_and_and_expected() + +@pytest.mark.parametrize("yaml_contents, name, expected_value", _trial_yamls_and_and_expected) +def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value, debug_test=False): + fitting_configs = configs_numbered_keys() + + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) + ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 1 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["train_file"] = "fit_multihead_dft.xyz" + mace_params["E0s"] = "{1:0.0,8:1.0}" + mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" + del mace_params["valid_fraction"] + mace_params["max_num_epochs"] = 1 # many tests to do + del mace_params["energy_key"] + del mace_params["forces_key"] + del mace_params["stress_key"] + + mace_params["name"] = "MACE_" + + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(dict_to_yaml_str(yaml_contents)) + if len(yaml_contents) > 0: + mace_params["config"] = str(tmp_path / "config.yaml") + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + if debug_test: + new_cmd = cmd.replace('--', '\n--') + print('calling run train with {name}') + print('command line args:\n', new_cmd) + print('config.yaml:\n', dict_to_yaml_str(yaml_contents), flush=True) + + p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) + assert p.returncode == 0 + + if 'heads' in yaml_contents: + headname = list(yaml_contents['heads'].keys())[0] + else: + headname = 'Default' + + calc = MACECalculator( + tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print(np.asarray(Es)) + print(expected_value) + print(type(np.asarray(Es))) + print(type(expected_value)) + assert np.allclose(np.asarray(Es), expected_value) \ No newline at end of file From 339231ea517ea9f6aa460247ab13ca773fb309a3 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Fri, 6 Sep 2024 12:33:10 +0100 Subject: [PATCH 14/27] fix tests seed --- tests/test_run_train.py | 7 +- tests/test_run_train_allkeys.py | 365 +++++++++++++++++--------------- 2 files changed, 198 insertions(+), 174 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index f751c02f..c99bdf83 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -537,9 +537,12 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): Es = [] for at in fitting_configs: - config_head = at.info.get('head', 'MP2') + config_head = at.info.get("head", "MP2") calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float64", head=config_head + tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, ) at.calc = calc Es.append(at.get_potential_energy()) diff --git a/tests/test_run_train_allkeys.py b/tests/test_run_train_allkeys.py index c3632a40..a6e76082 100644 --- a/tests/test_run_train_allkeys.py +++ b/tests/test_run_train_allkeys.py @@ -1,22 +1,19 @@ import os import subprocess import sys +from copy import deepcopy from pathlib import Path import ase.io import numpy as np import pytest from ase.atoms import Atoms -from copy import deepcopy from mace.calculators.mace import MACECalculator -import mace -np.random.seed(0) run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - _mace_params = { "name": "MACE", "valid_fraction": 0.05, @@ -44,6 +41,7 @@ def configs_numbered_keys(): + np.random.seed(0) water = Atoms( numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], @@ -56,23 +54,27 @@ def configs_numbered_keys(): trial_configs_lists = [] # some keys present, some not - keys_to_use = ["REF_energy"] + \ - ["2_energy"]*2 + \ - ["3_energy"]*3 + \ - ["4_energy"]*4 + \ - ["5_energy"]*5 - - force_keys_to_use = ["REF_forces"] + \ - ["2_forces"]*2 + \ - ["3_forces"]*3 + \ - ["4_forces"]*4 + \ - ["5_forces"]*5 + keys_to_use = ( + ["REF_energy"] + + ["2_energy"] * 2 + + ["3_energy"] * 3 + + ["4_energy"] * 4 + + ["5_energy"] * 5 + ) + + force_keys_to_use = ( + ["REF_forces"] + + ["2_forces"] * 2 + + ["3_forces"] * 3 + + ["4_forces"] * 4 + + ["5_forces"] * 5 + ) for ind in range(15): c = deepcopy(water) c.info[keys_to_use[ind]] = energies[ind] c.arrays[force_keys_to_use[ind]] = forces[ind] - c.positions += np.random.normal(0.1, size=(3,3)) + c.positions += np.random.normal(0.1, size=(3, 3)) trial_configs_lists.append(c) return trial_configs_lists @@ -140,9 +142,9 @@ def trial_yamls_and_and_expected(): "with_command_line": { key: {**command_line_kwargs, **value} for key, value in yamls.items() }, - "without_command_line": {key: value for key, value in yamls.items()}, + "without_command_line": yamls, } - + all_expected_outputs = { "with_command_line": { "no_heads": [ @@ -163,170 +165,171 @@ def trial_yamls_and_and_expected(): 0.9998734173248169, ], "one_head_no_dicts": [ - 1.0000166263499473, - 1.0021620416915131, - 1.0046772896383978, - 1.001465441141607, - 1.0055517812192685, - 1.0015992637882436, - 1.0020402319259156, - 1.0054369609690694, - 1.0048820789691186, - 1.004069245195459, - 1.0036930315433792, - 1.0045657994185517, - 1.0049657202069904, - 1.0054495991318766, - 1.0059574240719107, + 1.0028437510688613, + 1.0514693378041775, + 1.059933403321331, + 1.034719940573569, + 1.0438040675561824, + 1.019719477728329, + 0.9841759692947915, + 1.0435266573857496, + 1.0339501989779065, + 1.0501795448530264, + 1.0402594216704781, + 1.0604998765679152, + 1.0633411200246015, + 1.0539071190201297, + 1.0393496428177804, ], "one_head_with_dicts": [ - 0.9824761809087968, - 0.982723323954806, - 0.9804037844582393, - 0.9892979892015554, - 0.990123250174031, - 0.9872765633686582, - 0.9792985720223041, - 0.9834849185579561, - 0.9855709706241268, - 0.9838176625524332, - 0.9802380433794929, - 0.9798924747115749, - 0.9941246312362003, - 0.9843619552495816, - 1.0234402440454935, + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, ], "two_heads_no_dicts": [ - 0.9533172241443488, - 0.971143149409332, - 0.9591034423596022, - 0.9259180388268078, - 0.9866672025915887, - 0.9468387512978088, - 0.972806503955744, - 0.9268821579802152, - 0.9399783569634511, - 0.9566909477955546, - 1.0280484765877604, - 0.9638781804485581, - 0.9386762390303685, - 0.9513720471682103, - 1.061099519224079, + 0.9836377578288774, + 1.0196844186291318, + 1.0151628222871238, + 0.957307281711648, + 0.985574141310865, + 0.9629670134047853, + 0.9242583185138095, + 0.9807770070311039, + 0.9973679440479541, + 1.0221127246963275, + 1.0031807967874216, + 1.0358701219543687, + 1.0434208761164758, + 1.0235606028124515, + 0.9797494630655053, ], "two_heads_mixed": [ - 1.0008821794271117, - 0.9921658975489234, - 1.0128605897789047, - 1.0177680320432732, - 1.0040635968372489, - 1.0134535284156263, - 0.9900156994903402, - 0.9950077226207892, - 0.9931748657782218, - 0.9970869871816835, - 1.0036266515981311, - 0.9882332649269495, - 0.9973620987054619, - 1.0089283927259747, - 0.9984375026446699, + 0.8664108574741868, + 0.9907166576278023, + 1.0051969372365164, + 0.978702477000018, + 1.025500166764692, + 0.9940095566375018, + 0.9034029726954119, + 1.0391739502744488, + 0.9717327061183668, + 0.972292103670355, + 1.0012510461663253, + 0.9978051155885286, + 1.0378611651753475, + 1.0003207628186224, + 1.0209509292189651, ], }, "without_command_line": { "no_heads": [ - 0.9723249939003304, - 0.99830004939027, - 0.9976857883262907, - 1.0026915904907623, - 0.9986047122447201, - 1.0056392530400915, - 0.9955992271879338, - 0.9925618058915322, - 0.9992873743817391, - 1.0017751144824205, - 0.9965424145952742, - 0.9980104982304532, - 0.996970035434205, - 1.0017462160896793, - 1.00453025524217, + 0.9352605307451007, + 0.991084559389268, + 0.9940350095024881, + 0.9953849198103668, + 0.9954705498032904, + 0.9964815693808411, + 0.9663142667436776, + 0.9947223808739147, + 0.9897776682803257, + 0.989027769690667, + 0.9910280920241263, + 0.992067980667518, + 0.9917276132506404, + 0.9902848752169671, + 0.9928585982942544, ], "one_head_no_dicts": [ - 0.9668728328024694, - 0.9559554052674338, - 0.9558003309868804, - 0.9568681942948057, - 0.9471374531635678, - 0.9573665902279203, - 0.9509504944430629, - 0.9449430732494284, - 0.9487872001503757, - 0.9515435134805473, - 0.9616246560028083, - 0.9652201708552365, - 0.9518567860504985, - 0.9695448453855497, - 0.9595931614125687, + 0.9425342207393083, + 1.0149788456087416, + 1.0249228965652788, + 1.0247924743285792, + 1.02732103964481, + 1.0168852937950326, + 0.9771283495170653, + 1.0261776335561517, + 1.0130461033368028, + 1.0162619153561783, + 1.019995179866916, + 1.0209512298344965, + 1.0219971755636952, + 1.0195791901659124, + 1.0234662527729408, ], "one_head_with_dicts": [ - 0.9904238487805224, - 0.9787489784129528, - 0.9980000798872206, - 1.0081047579760913, - 0.970990405481672, - 1.0296635919726917, - 1.0070991842774164, - 0.9977357706770508, - 0.9729041794133619, - 0.9952167479342705, - 1.0256795692987708, - 1.0005027614317226, - 1.0042896304620599, - 0.9933015438418198, - 0.9941762126172496, + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, ], "two_heads_no_dicts": [ - 0.8234141049979373, - 0.8486132642907047, - 0.8761921831858267, - 0.8086446850523645, - 0.8185616207749478, - 0.8349295066652644, - 0.8695339796701849, - 0.8783625449137391, - 0.8513575832201994, - 0.8428073015147357, - 0.8514345324682252, - 0.8774982178381736, - 0.8724648944295484, - 0.9071025824523504, - 0.8671562526370659, + 0.9933763730233168, + 0.9986480398559268, + 1.0042486164355315, + 1.0025568793877726, + 1.0032598081704625, + 0.9926714183717912, + 0.9920385249670881, + 1.0020278841030676, + 1.0012474150830537, + 1.0039289677261019, + 1.0022718878661814, + 1.003586385624809, + 1.003436450009097, + 1.003805673887942, + 1.001450261102316, ], "two_heads_mixed": [ - 1.0142275963817828, - 0.9252946269851097, - 0.9905802472120683, - 1.0104854763203601, - 1.0627569806879018, - 0.894635070244004, - 0.9570335273959514, - 0.9917699286224028, - 0.9731498108644769, - 1.02712188692559, - 1.0255958579172193, - 1.0134291318470228, - 0.9601947878290134, - 0.9593860448787849, - 1.0044099804202045, + 0.8781767864616707, + 0.9843563603794138, + 1.0145197579049248, + 0.9835060778675391, + 1.0419060462994596, + 0.9917393978520056, + 0.9091521032773944, + 1.0605463095070453, + 0.9685381713826684, + 0.9866493058823766, + 1.00305061187164, + 1.0051273128414386, + 1.037964258398104, + 1.0106663924241408, + 1.0274351814133602, ], }, } - list_of_all = [] for key, value in all_arg_sets.items(): print(key) for key2, value2 in value.items(): - print(' ', key2) - list_of_all.append((value2, (key, key2), np.asarray(all_expected_outputs[key][key2]))) + print(" ", key2) + list_of_all.append( + (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) + ) return list_of_all @@ -344,8 +347,13 @@ def dict_to_yaml_str(data, indent=0): _trial_yamls_and_and_expected = trial_yamls_and_and_expected() -@pytest.mark.parametrize("yaml_contents, name, expected_value", _trial_yamls_and_and_expected) -def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value, debug_test=False): + +@pytest.mark.parametrize( + "yaml_contents, name, expected_value", _trial_yamls_and_and_expected +) +def test_key_specification_methods( + tmp_path, yaml_contents, name, expected_value, debug_test=False +): fitting_configs = configs_numbered_keys() ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) @@ -370,7 +378,7 @@ def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value mace_params["E0s"] = "{1:0.0,8:1.0}" mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" del mace_params["valid_fraction"] - mace_params["max_num_epochs"] = 1 # many tests to do + mace_params["max_num_epochs"] = 1 # many tests to do del mace_params["energy_key"] del mace_params["forces_key"] del mace_params["stress_key"] @@ -402,18 +410,18 @@ def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value ) if debug_test: - new_cmd = cmd.replace('--', '\n--') - print('calling run train with {name}') - print('command line args:\n', new_cmd) - print('config.yaml:\n', dict_to_yaml_str(yaml_contents), flush=True) + new_cmd = cmd.replace("--", "\n--") + print(f"calling run train with {name}") + print("command line args:\n", new_cmd) + print("config.yaml:\n", dict_to_yaml_str(yaml_contents), flush=True) p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) assert p.returncode == 0 - if 'heads' in yaml_contents: - headname = list(yaml_contents['heads'].keys())[0] + if "heads" in yaml_contents: + headname = list(yaml_contents["heads"].keys())[0] else: - headname = 'Default' + headname = "Default" calc = MACECalculator( tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname @@ -423,9 +431,22 @@ def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value for at in fitting_configs: at.calc = calc Es.append(at.get_potential_energy()) - - print(np.asarray(Es)) - print(expected_value) - print(type(np.asarray(Es))) - print(type(expected_value)) - assert np.allclose(np.asarray(Es), expected_value) \ No newline at end of file + + if debug_test: + return Es + + assert np.allclose(np.asarray(Es), expected_value) + return 0 + + +# for creating values +def make_output(): + outputs = {} + for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: + if name[0] not in outputs: + outputs[name[0]] = {} + expected = test_key_specification_methods( + Path("."), yaml_contents, name, expected_value, debug_test=False + ) + outputs[name[0]][name[1]] = expected + print(outputs) From 70341b1162959b68384c96023b031b7b184ab06b Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 9 Sep 2024 16:17:14 +0200 Subject: [PATCH 15/27] fix average e0s method --- mace/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index ee4678b8..7b2f6330 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -288,7 +288,7 @@ def compute_average_E0s( A = np.zeros((len_train, len_zs)) B = np.zeros(len_train) for i in range(len_train): - B[i] = collections_train[i].energy + B[i] = collections_train[i].properties["energy"] for j, z in enumerate(z_table.zs): A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) try: From 92d39ebce3e8291cdeeffd52bd9be8e8a9c233c0 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 9 Sep 2024 16:20:17 +0200 Subject: [PATCH 16/27] added missing charges and dipoles weights --- mace/data/atomic_data.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 07ce39c2..d6fdd3f4 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -57,6 +57,8 @@ def __init__( forces_weight: Optional[torch.Tensor], # [,] stress_weight: Optional[torch.Tensor], # [,] virials_weight: Optional[torch.Tensor], # [,] + dipole_weight: Optional[torch.Tensor], # [,] + charges_weight: Optional[torch.Tensor], # [,] forces: Optional[torch.Tensor], # [n_nodes, 3] energy: Optional[torch.Tensor], # [, ] stress: Optional[torch.Tensor], # [1,3,3] @@ -78,6 +80,8 @@ def __init__( assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 assert virials_weight is None or len(virials_weight.shape) == 0 + assert dipole_weight is None or dipole_weight.shape == (1,3), dipole_weight + assert charges_weight is None or len(charges_weight.shape) == 0 assert cell is None or cell.shape == (3, 3) assert forces is None or forces.shape == (num_nodes, 3) assert energy is None or len(energy.shape) == 0 @@ -100,6 +104,8 @@ def __init__( "forces_weight": forces_weight, "stress_weight": stress_weight, "virials_weight": virials_weight, + "dipole_weight": dipole_weight, + "charges_weight": charges_weight, "forces": forces, "energy": energy, "stress": stress, @@ -181,6 +187,26 @@ def from_config( else torch.tensor(1.0, dtype=torch.get_default_dtype()) ) + dipole_weight = ( + torch.tensor( + config.property_weights.get("dipole"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("dipole") is not None + else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + ) + if len(dipole_weight.shape) == 0: + dipole_weight = dipole_weight * torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + elif len(dipole_weight.shape) == 1: + dipole_weight = dipole_weight.unsqueeze(0) + + charges_weight = ( + torch.tensor( + config.property_weights.get("charges"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("charges") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + forces = ( torch.tensor( config.properties.get("forces"), dtype=torch.get_default_dtype() @@ -241,6 +267,8 @@ def from_config( forces_weight=forces_weight, stress_weight=stress_weight, virials_weight=virials_weight, + dipole_weight=dipole_weight, + charges_weight=charges_weight, forces=forces, energy=energy, stress=stress, From e9e2779c3006dba52229e6c6087ca6e2b9158a36 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 9 Sep 2024 16:22:42 +0200 Subject: [PATCH 17/27] linting --- mace/data/atomic_data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index d6fdd3f4..92fa6b22 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -80,7 +80,7 @@ def __init__( assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 assert virials_weight is None or len(virials_weight.shape) == 0 - assert dipole_weight is None or dipole_weight.shape == (1,3), dipole_weight + assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight assert charges_weight is None or len(charges_weight.shape) == 0 assert cell is None or cell.shape == (3, 3) assert forces is None or forces.shape == (num_nodes, 3) @@ -195,7 +195,9 @@ def from_config( else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) ) if len(dipole_weight.shape) == 0: - dipole_weight = dipole_weight * torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + dipole_weight = dipole_weight * torch.tensor( + [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() + ) elif len(dipole_weight.shape) == 1: dipole_weight = dipole_weight.unsqueeze(0) From c0b65e209733cd47e038ad6a7e8d95c3e1fe3302 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Mon, 9 Sep 2024 16:23:08 +0200 Subject: [PATCH 18/27] moved keyspec construction into run_train --- mace/cli/run_train.py | 16 +++++++++++++++- mace/tools/multihead_tools.py | 11 ++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e7271374..56c0167b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -157,12 +157,26 @@ def run(args: argparse.Namespace) -> None: if args.heads is not None: args.heads = ast.literal_eval(args.heads) + for _, head_dict in args.heads.items(): + # priority is global args < head property_key values < head info_keys+arrays_keys + head_keyspec = deepcopy(args.key_specification) + update_keyspec_from_kwargs(head_keyspec, head_dict) + head_keyspec.update( + info_keys=head_dict.get("info_keys", {}), + arrays_keys=head_dict.get("arrays_keys", {}), + ) + head_dict["key_specification"] = head_keyspec else: args.heads = prepare_default_head(args) logging.info("===========LOADING INPUT DATA===========") heads = list(args.heads.keys()) logging.info(f"Using heads: {heads}") + logging.info("Using the key specifications to parse data:") + for name, head_dict in args.heads.items(): + head_keyspec = head_dict["key_specification"] + logging.info(f"{name}: {head_keyspec}") + head_configs: List[HeadConfig] = [] for head, head_args in args.heads.items(): logging.info(f"============= Processing head {head} ===========") @@ -647,7 +661,7 @@ def run(args: argparse.Namespace) -> None: folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): - print(test_name) + logging.info("test_name", test_name) test_sampler = None if args.distributed: test_sampler = torch.utils.data.distributed.DistributedSampler( diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 369d1e10..16f6e491 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -3,13 +3,12 @@ import logging import os import urllib.request -from copy import deepcopy from typing import Any, Dict, List, Optional, Union import torch from mace.cli.fine_tuning_select import select_samples -from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.data import KeySpecification from mace.tools.scripts_utils import ( SubsetCollection, dict_to_namespace, @@ -44,12 +43,6 @@ class HeadConfig: def dict_head_to_dataclass( head: Dict[str, Any], head_name: str, args: argparse.Namespace ) -> HeadConfig: - # priority is global args < head property_key values < head info_keys+arrays_keys - head_keyspec = deepcopy(args.key_specification) - update_keyspec_from_kwargs(head_keyspec, head) - head_keyspec.update( - info_keys=head.get("info_keys", {}), arrays_keys=head.get("arrays_keys", {}) - ) # parser+head args that have no defaults but are required if (args.train_file is None) and (head.get("train_file", None) is None): raise ValueError( @@ -72,7 +65,7 @@ def dict_head_to_dataclass( mean=head.get("mean", args.mean), std=head.get("std", args.std), avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - key_specification=head_keyspec, + key_specification=head["key_specification"], keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) From 3b0c34fb9f7d955275450480b0b47311e88e66c6 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Wed, 18 Sep 2024 15:47:26 +0100 Subject: [PATCH 19/27] pass copies to neighborhood --- mace/data/atomic_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 92fa6b22..3c3f10c2 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -18,7 +18,7 @@ from .neighborhood import get_neighborhood from .utils import Configuration - +from copy import deepcopy class AtomicData(torch_geometric.data.Data): num_graphs: torch.Tensor @@ -127,7 +127,7 @@ def from_config( if heads is None: heads = ["Default"] edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell + positions=config.positions, cutoff=cutoff, pbc=deepcopy(config.pbc), cell=deepcopy(config.cell) ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) one_hot = to_one_hot( From b78d1ec8b7ad1a9c4f58565ea4641e80af5932c7 Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 22 Oct 2024 14:04:58 +0100 Subject: [PATCH 20/27] convience function for logging dataset stats --- mace/tools/scripts_utils.py | 61 +++++++++++++++---------------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 17c6df5e..976caf75 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -32,6 +32,25 @@ class SubsetCollection: tests: List[Tuple[str, data.Configurations]] +def log_dataset_contents(dataset, dataset_name, keyspec): + all_property_names = list(keyspec.info_keys.keys()) + list(keyspec.arrays_keys.keys()) + log_string = f'{dataset_name} [' + for prop_name in all_property_names: + if prop_name == 'dipole': + log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " + else: + log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " + """ except ValueError: + config = dataset[0] + print(config.property_weights) + print(prop_name) + print(config.property_weights[prop_name]) + print(config) + exit(0) """ + log_string = log_string[:-2] + ']' + logging.info(log_string) + + def get_dataset_from_xyz( work_dir: str, train_path: str, @@ -53,15 +72,8 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, head_name=head_name, ) - num_energies = int( - np.sum([config.property_weights["energy"] for config in all_train_configs]) - ) - num_forces = int( - np.sum([config.property_weights["forces"] for config in all_train_configs]) - ) - logging.info( - f"Training set [{len(all_train_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{train_path}'" - ) + log_dataset_contents(all_train_configs, 'Training set', key_specification) + if valid_path is not None: _, valid_configs = data.load_from_xyz( file_path=valid_path, @@ -70,29 +82,14 @@ def get_dataset_from_xyz( extract_atomic_energies=False, head_name=head_name, ) - num_energies = int( - np.sum([config.property_weights["energy"] for config in valid_configs]) - ) - num_forces = int( - np.sum([config.property_weights["forces"] for config in valid_configs]) - ) - logging.info( - f"Validation set [{len(valid_configs)} configs, {num_energies} energy, {num_forces} forces] loaded from '{valid_path}'" - ) + log_dataset_contents(valid_configs, 'Validation set', key_specification) train_configs = all_train_configs else: train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed, work_dir ) - num_energies = int( - np.sum([config.property_weights["energy"] for config in valid_configs]) - ) - num_forces = int( - np.sum([config.property_weights["forces"] for config in valid_configs]) - ) - logging.info( - f"Validation set contains {len(valid_configs)} configs, [{num_energies} energy, {num_forces} forces]" - ) + log_dataset_contents(train_configs, 'Random Split Training set', key_specification) + log_dataset_contents(valid_configs, 'Random Split Validation set', key_specification) test_configs = [] if test_path is not None: _, all_test_configs = data.load_from_xyz( @@ -108,15 +105,7 @@ def get_dataset_from_xyz( f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) for name, tmp_configs in test_configs: - num_energies = int( - np.sum([config.property_weights["energy"] for config in tmp_configs]) - ) - num_forces = int( - np.sum([config.property_weights["forces"] for config in tmp_configs]) - ) - logging.info( - f"{name}: {len(tmp_configs)} configs, {num_energies} energy, {num_forces} forces" - ) + log_dataset_contents(tmp_configs, f'Test set {name}', key_specification) return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), From eed2a4135ab7720b97ad7da718580301480241b7 Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 22 Oct 2024 14:14:30 +0100 Subject: [PATCH 21/27] formatting --- mace/data/atomic_data.py | 8 ++++++-- mace/tools/scripts_utils.py | 33 ++++++++++++++++----------------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 3c3f10c2..b301a143 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -4,6 +4,7 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +from copy import deepcopy from typing import Optional, Sequence import torch.utils.data @@ -18,7 +19,7 @@ from .neighborhood import get_neighborhood from .utils import Configuration -from copy import deepcopy + class AtomicData(torch_geometric.data.Data): num_graphs: torch.Tensor @@ -127,7 +128,10 @@ def from_config( if heads is None: heads = ["Default"] edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, cutoff=cutoff, pbc=deepcopy(config.pbc), cell=deepcopy(config.cell) + positions=config.positions, + cutoff=cutoff, + pbc=deepcopy(config.pbc), + cell=deepcopy(config.cell), ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) one_hot = to_one_hot( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 976caf75..68e618be 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -33,21 +33,16 @@ class SubsetCollection: def log_dataset_contents(dataset, dataset_name, keyspec): - all_property_names = list(keyspec.info_keys.keys()) + list(keyspec.arrays_keys.keys()) - log_string = f'{dataset_name} [' + all_property_names = list(keyspec.info_keys.keys()) + list( + keyspec.arrays_keys.keys() + ) + log_string = f"{dataset_name} [" for prop_name in all_property_names: - if prop_name == 'dipole': + if prop_name == "dipole": log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " else: log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " - """ except ValueError: - config = dataset[0] - print(config.property_weights) - print(prop_name) - print(config.property_weights[prop_name]) - print(config) - exit(0) """ - log_string = log_string[:-2] + ']' + log_string = log_string[:-2] + "]" logging.info(log_string) @@ -72,8 +67,8 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, head_name=head_name, ) - log_dataset_contents(all_train_configs, 'Training set', key_specification) - + log_dataset_contents(all_train_configs, "Training set", key_specification) + if valid_path is not None: _, valid_configs = data.load_from_xyz( file_path=valid_path, @@ -82,14 +77,18 @@ def get_dataset_from_xyz( extract_atomic_energies=False, head_name=head_name, ) - log_dataset_contents(valid_configs, 'Validation set', key_specification) + log_dataset_contents(valid_configs, "Validation set", key_specification) train_configs = all_train_configs else: train_configs, valid_configs = data.random_train_valid_split( all_train_configs, valid_fraction, seed, work_dir ) - log_dataset_contents(train_configs, 'Random Split Training set', key_specification) - log_dataset_contents(valid_configs, 'Random Split Validation set', key_specification) + log_dataset_contents( + train_configs, "Random Split Training set", key_specification + ) + log_dataset_contents( + valid_configs, "Random Split Validation set", key_specification + ) test_configs = [] if test_path is not None: _, all_test_configs = data.load_from_xyz( @@ -105,7 +104,7 @@ def get_dataset_from_xyz( f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) for name, tmp_configs in test_configs: - log_dataset_contents(tmp_configs, f'Test set {name}', key_specification) + log_dataset_contents(tmp_configs, f"Test set {name}", key_specification) return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), From b6bf6c0c9abca365cb6aaee1f17a8a926587a6e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rokas=20Elijo=C5=A1ius?= Date: Thu, 24 Oct 2024 11:43:59 +0100 Subject: [PATCH 22/27] fix type hint --- mace/data/neighborhood.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 21296fa6..03728969 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -10,7 +10,7 @@ def get_neighborhood( pbc: Optional[Tuple[bool, bool, bool]] = None, cell: Optional[np.ndarray] = None, # [3, 3] true_self_interaction=False, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: if pbc is None: pbc = (False, False, False) From 5b44fcebb5ebcaf766e4715205717acdabcf7314 Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 29 Oct 2024 13:41:45 +0000 Subject: [PATCH 23/27] minor fixes from review --- mace/data/atomic_data.py | 2 ++ mace/data/utils.py | 4 ++-- mace/tools/scripts_utils.py | 2 +- tests/test_run_train_allkeys.py | 29 ++++++++++------------------- 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index b301a143..14dd39df 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -43,6 +43,8 @@ class AtomicData(torch_geometric.data.Data): forces_weight: torch.Tensor stress_weight: torch.Tensor virials_weight: torch.Tensor + dipole_weight: torch.Tensor + charges_weight: torch.Tensor def __init__( self, diff --git a/mace/data/utils.py b/mace/data/utils.py index 7b2f6330..5659adc7 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -40,7 +40,7 @@ def update( return self -def update_keyspec_from_kwargs(keyspec, keydict) -> KeySpecification: +def update_keyspec_from_kwargs(keyspec: KeySpecification, keydict:Dict[str, str]) -> KeySpecification: # convert command line style property_key arguments into a keyspec infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] arrays = ["forces_key", "charges_key"] @@ -150,7 +150,7 @@ def config_from_atoms( properties = {} property_weights = {} for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): - property_weights[name] = atoms.info.get("config_" + name + "_weight", 1.0) + property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) for name, atoms_key in key_specification.info_keys.items(): properties[name] = atoms.info.get(atoms_key, None) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 68e618be..7b907cae 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -32,7 +32,7 @@ class SubsetCollection: tests: List[Tuple[str, data.Configurations]] -def log_dataset_contents(dataset, dataset_name, keyspec): +def log_dataset_contents(dataset: data.Configurations, dataset_name: str, keyspec: KeySpecification) -> None: all_property_names = list(keyspec.info_keys.keys()) + list( keyspec.arrays_keys.keys() ) diff --git a/tests/test_run_train_allkeys.py b/tests/test_run_train_allkeys.py index a6e76082..5807defe 100644 --- a/tests/test_run_train_allkeys.py +++ b/tests/test_run_train_allkeys.py @@ -32,11 +32,20 @@ "amsgrad": None, "device": "cpu", "seed": 5, - "loss": "stress", + "loss": "weighted", "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "interaction_first": "RealAgnosticResidualInteractionBlock", + "batch_size": 1, + "valid_batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", "eval_interval": 2, + "num_radial_basis": 10, + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", } @@ -324,9 +333,7 @@ def trial_yamls_and_and_expected(): list_of_all = [] for key, value in all_arg_sets.items(): - print(key) for key2, value2 in value.items(): - print(" ", key2) list_of_all.append( (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) ) @@ -364,16 +371,6 @@ def test_key_specification_methods( mace_params["valid_fraction"] = 0.1 mace_params["checkpoints_dir"] = str(tmp_path) mace_params["model_dir"] = str(tmp_path) - mace_params["loss"] = "weighted" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 1 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" mace_params["train_file"] = "fit_multihead_dft.xyz" mace_params["E0s"] = "{1:0.0,8:1.0}" mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" @@ -409,12 +406,6 @@ def test_key_specification_methods( ) ) - if debug_test: - new_cmd = cmd.replace("--", "\n--") - print(f"calling run train with {name}") - print("command line args:\n", new_cmd) - print("config.yaml:\n", dict_to_yaml_str(yaml_contents), flush=True) - p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) assert p.returncode == 0 From c00410c1091be2cdc97e8ac35c29feeabdd904fa Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 29 Oct 2024 14:59:43 +0000 Subject: [PATCH 24/27] fixes for new tests and linting --- tests/test_modules.py | 71 +++++++++++++++++++++------------ tests/test_preprocess.py | 16 ++++---- tests/test_run_train.py | 10 ++++- tests/test_run_train_allkeys.py | 10 ++--- 4 files changed, 64 insertions(+), 43 deletions(-) diff --git a/tests/test_modules.py b/tests/test_modules.py index 9775ca6d..6afcccfb 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -30,15 +30,22 @@ def _config(): [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, - stress=np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "stress": 1.0, + }, ) @@ -58,14 +65,20 @@ def _config1(): [0.0, 1.0, 0.0], ] ), - forces=np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - energy=-1.5, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, head="DFT", ) @@ -81,14 +94,20 @@ def _config2(): [0.1, 1.1, 0.1], ] ), - forces=np.array( - [ - [0.1, -1.2, 0.1], - [1.1, 0.3, 0.1], - [0.1, 1.2, 0.4], - ] - ), - energy=-1.4, + properties={ + "forces": np.array( + [ + [0.1, -1.2, 0.1], + [1.1, 0.3, 0.1], + [0.1, 1.2, 0.4], + ] + ), + "energy": -1.4, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, head="MP2", ) @@ -246,4 +265,4 @@ def test_compute_statistics(data_loader, atomic_energies): assert np.all(mean != 0) assert np.all(std > 0) assert mean[0] != mean[1] - assert std[0] != std[1] \ No newline at end of file + assert std[0] != std[1] diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index e0258bd4..c1cb28c4 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -110,8 +110,8 @@ def test_preprocess_data(tmp_path, sample_configs): config = f["config_batch_0"]["config_0"] assert "atomic_numbers" in config assert "positions" in config - assert "energy" in config - assert "forces" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] original_energies = [ config.info["REF_energy"] @@ -134,19 +134,19 @@ def test_preprocess_data(tmp_path, sample_configs): config = batch[config_key] assert "atomic_numbers" in config assert "positions" in config - assert "energy" in config - assert "forces" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] - h5_energies.append(config["energy"][()]) - h5_forces.append(config["forces"][()]) + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) for val_file in val_files: with h5py.File(val_file, "r") as f: for _, batch in f.items(): for config_key in batch.keys(): config = batch[config_key] - h5_energies.append(config["energy"][()]) - h5_forces.append(config["forces"][()]) + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) print("Original energies", original_energies) print("H5 energies", h5_energies) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 42f4dd7d..ae4f6041 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -376,7 +376,10 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64", head="CCD" + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="CCD", ) Es = [] @@ -718,7 +721,10 @@ def test_run_train_multihead_replay_custum_finetuning( # Load and test the finetuned model calc = MACECalculator( - model_paths=tmp_path / "finetuned.model", device="cpu", default_dtype="float64" + model_paths=tmp_path / "finetuned.model", + device="cpu", + default_dtype="float64", + head="pt_head", ) Es = [] diff --git a/tests/test_run_train_allkeys.py b/tests/test_run_train_allkeys.py index 5807defe..275b13bc 100644 --- a/tests/test_run_train_allkeys.py +++ b/tests/test_run_train_allkeys.py @@ -22,8 +22,6 @@ "stress_weight": 1.0, "model": "MACE", "hidden_irreps": "128x0e", - "r_max": 3.5, - "batch_size": 5, "max_num_epochs": 10, "swa": None, "start_swa": 5, @@ -43,7 +41,6 @@ "subselect_pt": "random", "eval_interval": 2, "num_radial_basis": 10, - "hidden_irreps": "128x0e", "r_max": 6.0, "default_dtype": "float64", } @@ -359,7 +356,7 @@ def dict_to_yaml_str(data, indent=0): "yaml_contents, name, expected_value", _trial_yamls_and_and_expected ) def test_key_specification_methods( - tmp_path, yaml_contents, name, expected_value, debug_test=False + tmp_path, yaml_contents, name, expected_value ): fitting_configs = configs_numbered_keys() @@ -423,11 +420,10 @@ def test_key_specification_methods( at.calc = calc Es.append(at.get_potential_energy()) - if debug_test: - return Es + print(name) + print("Es", Es) assert np.allclose(np.asarray(Es), expected_value) - return 0 # for creating values From 908acd125e5a8eef80e316950c52a081c6d67ab9 Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 29 Oct 2024 15:00:37 +0000 Subject: [PATCH 25/27] head key in preprocessor --- mace/tools/arg_parser.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index c4d3a9ec..f64fe41b 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -844,6 +844,12 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: type=int, default=123, ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default="head", + ) return parser From 808f79456420acaa4d51eadb84fd4b4999af3bb0 Mon Sep 17 00:00:00 2001 From: "W.J. Baldwin" Date: Tue, 29 Oct 2024 15:00:56 +0000 Subject: [PATCH 26/27] formatting --- mace/cli/eval_configs.py | 4 ++-- mace/data/utils.py | 4 +++- mace/tools/scripts_utils.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index f44f7515..86470f8d 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -58,7 +58,7 @@ def parse_args() -> argparse.Namespace: help="Model head used for evaluation", type=str, required=False, - default=None + default=None, ) return parser.parse_args() @@ -94,7 +94,7 @@ def run(args: argparse.Namespace) -> None: heads = model.heads except AttributeError: heads = None - + data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( diff --git a/mace/data/utils.py b/mace/data/utils.py index 5659adc7..a552f1a8 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -40,7 +40,9 @@ def update( return self -def update_keyspec_from_kwargs(keyspec: KeySpecification, keydict:Dict[str, str]) -> KeySpecification: +def update_keyspec_from_kwargs( + keyspec: KeySpecification, keydict: Dict[str, str] +) -> KeySpecification: # convert command line style property_key arguments into a keyspec infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] arrays = ["forces_key", "charges_key"] diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index d517f066..4ac51b26 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -33,7 +33,9 @@ class SubsetCollection: tests: List[Tuple[str, data.Configurations]] -def log_dataset_contents(dataset: data.Configurations, dataset_name: str, keyspec: KeySpecification) -> None: +def log_dataset_contents( + dataset: data.Configurations, dataset_name: str, keyspec: KeySpecification +) -> None: all_property_names = list(keyspec.info_keys.keys()) + list( keyspec.arrays_keys.keys() ) From 7a19ed6e9defc9bd66fd04cc0e9653592eb38376 Mon Sep 17 00:00:00 2001 From: Will Baldwin Date: Tue, 29 Oct 2024 17:42:06 +0000 Subject: [PATCH 27/27] new calculator syntax in test_run_train --- tests/test_run_train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index bb599a98..80d7eaa7 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -686,15 +686,19 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): p = subprocess.run(cmd.split(), env=run_env, check=True) assert p.returncode == 0 - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - Es = [] for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) at.calc = calc Es.append(at.get_potential_energy()) + print("Es", Es) # from a run on 20/08/2024 on commit ref_Es = [