From f764be4582de56ed6eda68dc08a74f32870e23c2 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 21 Oct 2024 04:40:01 -0600 Subject: [PATCH 01/12] Add mace performance benchmark --- .gitignore | 2 ++ mace/tools/torch_tools.py | 13 ++++--- tests/test_benchmark.py | 76 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 tests/test_benchmark.py diff --git a/.gitignore b/.gitignore index 3817d9f3..296776e4 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ dist/ *.xyz /checkpoints *.model + +.benchmarks diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index e42a74f8..31e837df 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -6,7 +6,7 @@ import logging from contextlib import contextmanager -from typing import Dict +from typing import Dict, Union import numpy as np import torch @@ -129,13 +129,18 @@ def init_wandb(project: str, entity: str, name: str, config: dict, directory: st @contextmanager -def default_dtype(dtype: torch.dtype): +def default_dtype(dtype: Union[torch.dtype, str]): """Context manager for configuring the default_dtype used by torch Args: - dtype (torch.dtype): the default dtype to use within this context manager + dtype (torch.dtype|str): the default dtype to use within this context manager """ init = torch.get_default_dtype() - torch.set_default_dtype(dtype) + if isinstance(dtype, str): + set_default_dtype(dtype) + else: + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(init) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py new file mode 100644 index 00000000..bbbcdf36 --- /dev/null +++ b/tests/test_benchmark.py @@ -0,0 +1,76 @@ +import os +from typing import Optional + +import pytest +import torch +from ase import build + +from mace import data +from mace.calculators.foundations_models import mace_mp +from mace.tools import AtomicNumberTable, torch_geometric, torch_tools + + +def is_mace_full_bench(): + return os.environ.get("MACE_FULL_BENCH", "0") == "1" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8) +@pytest.mark.parametrize("size", (3, 5, 7, 9)) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize("compile_mode", [None, "default"]) +def test_inference( + benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda" +): + if not is_mace_full_bench() and compile_mode is not None: + pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute") + + with torch_tools.default_dtype(dtype): + model = load_mace_mp_medium(dtype, compile_mode, device) + batch = create_batch(size, model, device) + log_bench_info(benchmark, dtype, compile_mode, batch) + + def func(): + torch.cuda.synchronize() + model(batch, training=compile_mode is not None, compute_force=True) + + torch.cuda.empty_cache() + benchmark(func) + + +def load_mace_mp_medium(dtype, compile_mode, device): + calc = mace_mp( + model="medium", + default_dtype=dtype, + device=device, + compile_mode=compile_mode, + fullgraph=False, + ) + model = calc.models[0].to(device) + return model + + +def create_batch(size: int, model: torch.nn.Module, device: str) -> dict: + cutoff = model.r_max.item() + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms = atoms.repeat((size, size, size)) + config = data.config_from_atoms(atoms) + 
dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch.to(device) + return batch.to_dict() + + +def log_bench_info(benchmark, dtype, compile_mode, batch): + benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0]) + benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) + benchmark.extra_info["dtype"] = dtype + benchmark.extra_info["is_compiled"] = compile_mode is not None + benchmark.extra_info["name"] = torch.cuda.get_device_name() From 6d3aa91d054c1423195f4412fc61a49308c14d7e Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 21 Oct 2024 07:19:08 -0600 Subject: [PATCH 02/12] include benchmark metric post-processing --- tests/test_benchmark.py | 48 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index bbbcdf36..78b04ccd 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -1,6 +1,8 @@ import os from typing import Optional +import pandas as pd +import json import pytest import torch from ase import build @@ -73,4 +75,48 @@ def log_bench_info(benchmark, dtype, compile_mode, batch): benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) benchmark.extra_info["dtype"] = dtype benchmark.extra_info["is_compiled"] = compile_mode is not None - benchmark.extra_info["name"] = torch.cuda.get_device_name() + benchmark.extra_info["device_name"] = torch.cuda.get_device_name() + + +def read_bench_results(files: list[str]) -> pd.DataFrame: + def read(file): + with open(file, "r") as f: + data = json.load(f) + + records = [] + for bench in data["benchmarks"]: + record = {**bench["extra_info"], **bench["stats"]} + records.append(record) + + df = pd.DataFrame(records) + df["ns/day (1 fs/step)"] = 0.086400 / df["median"] + df["Steps per day"] = df["ops"] * 86400 + columns = [ + "num_atoms", + "num_edges", + "dtype", + "is_compiled", + "device_name", + "median", + "Steps per day", + "ns/day (1 fs/step)", + ] + return df[columns] + + return pd.concat([read(f) for f in files]) + + +if __name__ == "__main__": + # Print to stdout a csv of the benchmark metrics + import subprocess + + result = subprocess.run( + ["pytest-benchmark", "list"], capture_output=True, text=True + ) + + if result.returncode != 0: + raise RuntimeError(f"Command failed with return code {result.returncode}") + + files = result.stdout.strip().split("\n") + df = read_bench_results(files) + print(df.to_csv(index=False)) From 0e5d72b5a72be9724035a0ee97a0b42af96e8e8d Mon Sep 17 00:00:00 2001 From: Thomas Warford Date: Sat, 9 Nov 2024 23:58:41 +0000 Subject: [PATCH 03/12] add head_key=head_name to info dict of atoms regardless of isolated_atom_config_value --- mace/data/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index bb8e5448..59b868ed 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -265,7 +265,6 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) @@ -288,6 +287,9 @@ def load_from_xyz( if not keep_isolated_atoms: atoms_list = atoms_without_iso_atoms + for atoms in atoms_list: + atoms.info[head_key] = head_name + configs = 
config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, From 0d5e22243414378347a39f2fa1e342be5df6db36 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:13:27 +0000 Subject: [PATCH 04/12] Add config option for pre-processing --- mace/tools/arg_parser.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index cb4f8ac5..a864f0a3 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -714,9 +714,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser: def build_preprocess_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to agregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( "--train_file", help="Training set h5 file", From 293fe6075caa36753a46e1250f7fea15e1289cde Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:53:56 +0000 Subject: [PATCH 05/12] Test preprocessing config --- tests/test_preprocess.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index e0258bd4..1f3068ba 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -6,6 +6,7 @@ import ase.io import numpy as np import pytest +import yaml from ase.atoms import Atoms pytest_mace_dir = Path(__file__).parent.parent @@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs): np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) print("All checks passed successfully!") + + +def test_preprocess_config(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": str(tmp_path / "sample.xyz"), + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": str(tmp_path / "preprocessed_"), + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + yaml.dump(preprocess_params, file) + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + "--config" + + " " + + str(filename) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 From 5e1c4dd6ffe747f49a7d11a9846b925f50de8ab6 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Wed, 30 Oct 2024 18:06:09 +0000 Subject: [PATCH 06/12] Fix typos --- mace/tools/arg_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index a864f0a3..a49a722e 100644 
--- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--config", type=str, is_config_file=True, - help="config file to agregate options", + help="config file to aggregate options", ) except ImportError: parser = argparse.ArgumentParser( @@ -725,7 +725,7 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: "--config", type=str, is_config_file=True, - help="config file to agregate options", + help="config file to aggregate options", ) except ImportError: parser = argparse.ArgumentParser( From d76ec18c7bbde3f02d6ba6469f038359c7529f7e Mon Sep 17 00:00:00 2001 From: Alin Marin Elena Date: Thu, 28 Nov 2024 11:32:29 +0000 Subject: [PATCH 07/12] now actually download the new models --- mace/calculators/foundations_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 51f27ab0..32f26848 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: checkpoint_url = ( urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") + if model in (None, "small", "medium", "large", "small", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") else model ) @@ -106,7 +106,7 @@ def mace_mp( MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). """ try: - if model in (None, "small", "medium", "large") or str(model).startswith( + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith( "https:" ): model_path = download_mace_mp_checkpoint(model) From 0e759b635477fcc4d1fe851f5fc23f0ee6a235cc Mon Sep 17 00:00:00 2001 From: Alin Marin Elena Date: Thu, 28 Nov 2024 12:13:28 +0000 Subject: [PATCH 08/12] Update mace/calculators/foundations_models.py Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> --- mace/calculators/foundations_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index 32f26848..3ebddbed 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: checkpoint_url = ( urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large", "small", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") else model ) From ce09434cc6ce581edcbcfc490255d83486172880 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:52:38 +0000 Subject: [PATCH 09/12] change learning rate for multihead ft --- mace/cli/run_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index eaf8cdbd..ad2b96f5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -160,6 +160,12 @@ def run(args: argparse.Namespace) -> None: args.E0s != "average" ), "average atomic energies cannot be used for multiheads finetuning" # check that the foundation model has a single head, if not, use the first head + args.lr = 0.001 + args.ema = True + args.ema_decay = 0.999 + 
logging.info( + "Using multiheads finetuning mode, setting learning rate to 0.001 and EMA to True" + ) if hasattr(model_foundation, "heads"): if len(model_foundation.heads) > 1: logging.warning( From efaca2c3fe9afe94dc184b829914222c3f5d13d4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:18:55 +0000 Subject: [PATCH 10/12] add option to force lr arg for mh_ft --- mace/cli/run_train.py | 10 +++++++--- mace/tools/arg_parser.py | 6 ++++++ tests/test_run_train.py | 10 +++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index ad2b96f5..e8319ac7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -160,9 +160,13 @@ def run(args: argparse.Namespace) -> None: args.E0s != "average" ), "average atomic energies cannot be used for multiheads finetuning" # check that the foundation model has a single head, if not, use the first head - args.lr = 0.001 - args.ema = True - args.ema_decay = 0.999 + if not args.force_mh_ft_lr: + logging.info( + "Multihead finetuning mode, setting learning rate to 0.001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." + ) + args.lr = 0.001 + args.ema = True + args.ema_decay = 0.999 logging.info( "Using multiheads finetuning mode, setting learning rate to 0.001 and EMA to True" ) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 07e02e49..cbce368a 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -379,6 +379,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=int, default=1000, ) + parser.add_argument( + "--force_mh_ft_lr", + help="Force the multiheaded fine-tuning to use arg_parser lr", + type=str2bool, + default=False, + ) parser.add_argument( "--subselect_pt", help="Method to subselect the configurations of the pretraining set", diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 894f66c8..2b56c10b 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -437,14 +437,10 @@ def test_run_train_foundation(tmp_path, fitting_configs): mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["multiheads_finetuning"] = False - print("mace_params", mace_params) - # mace_params["num_samples_pt"] = 50 - # mace_params["subselect_pt"] = "random" - # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) cmd = ( sys.executable @@ -549,6 +545,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): mace_params["valid_batch_size"] = 1 mace_params["num_samples_pt"] = 50 mace_params["subselect_pt"] = "random" + mace_params["force_mh_ft_lr"] = True # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -666,6 +663,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): mace_params["valid_batch_size"] = 1 mace_params["num_samples_pt"] = 50 mace_params["subselect_pt"] = "random" + mace_params["force_mh_ft_lr"] = True # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -827,6 +825,7 @@ def 
test_run_train_multihead_replay_custum_finetuning( "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), "num_samples_pt": 3, "subselect_pt": "random", + "force_mh_ft_lr": True, } cmd = [sys.executable, str(run_train)] @@ -993,6 +992,7 @@ def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): mace_params["num_samples_pt"] = 50 mace_params["subselect_pt"] = "random" mace_params["enable_cueq"] = True + mace_params["force_mh_ft_lr"] = True # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) From ae2d46105c96b05ee95dafaf8a2e18a4ed5fa451 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:33:57 +0000 Subject: [PATCH 11/12] fix cueq calc descriptors --- mace/calculators/mace.py | 3 ++- tests/test_calculator.py | 48 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index fc88c051..5789d0d3 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -14,6 +14,7 @@ import torch from ase.calculators.calculator import Calculator, all_changes from ase.stress import full_3x3_to_voigt_6_stress +from e3nn import o3 from mace import data from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq @@ -406,7 +407,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] - irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] + irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out)) l_max = irreps_out.lmax num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 per_layer_features = [irreps_out.dim for _ in range(num_interactions)] diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 158cad64..9988e20b 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -521,6 +521,54 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): assert not np.allclose(desc, desc_rotated, atol=1e-6) +def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model): + at = fitting_configs[2].copy() + at_rotated = fitting_configs[2].copy() + at_rotated.rotate(90, "x") + calc = trained_equivariant_model + + desc_invariant = calc.get_descriptors(at, invariants_only=True, enable_cueq=True) + desc_invariant_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, enable_cueq=True + ) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1, enable_cueq=True + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, num_layers=1, enable_cueq=True + ) + desc = calc.get_descriptors(at, invariants_only=False, enable_cueq=True) + desc_single_layer = calc.get_descriptors( + at, invariants_only=False, num_layers=1, enable_cueq=True + ) + desc_rotated = calc.get_descriptors( + at_rotated, invariants_only=False, enable_cueq=True + ) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1, enable_cueq=True + ) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + assert 
desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 + + np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose(desc, desc_rotated, atol=1e-6) + + def test_mace_mp(capsys: pytest.CaptureFixture): mp_mace = mace_mp() assert isinstance(mp_mace, MACECalculator) From 44caf8916d0cfe2083eabaea9788f4903477f371 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:02:35 +0000 Subject: [PATCH 12/12] fix calc cueq test --- tests/test_calculator.py | 99 +++++++++++++++++++++++++++++++++------- 1 file changed, 82 insertions(+), 17 deletions(-) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 9988e20b..6c9e2563 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -184,6 +184,71 @@ def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") +@pytest.fixture(scope="module", name="trained_equivariant_model_cueq") +def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True + ) + + @pytest.fixture(scope="module", name="trained_dipole_model") def trained_dipole_fixture(tmp_path_factory, fitting_configs): _mace_params = { @@ -521,31 +586,25 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model): assert not np.allclose(desc, desc_rotated, atol=1e-6) -def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model): +def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq): at = fitting_configs[2].copy() at_rotated = fitting_configs[2].copy() at_rotated.rotate(90, "x") - calc = trained_equivariant_model + calc = trained_equivariant_model_cueq 
- desc_invariant = calc.get_descriptors(at, invariants_only=True, enable_cueq=True) - desc_invariant_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, enable_cueq=True - ) + desc_invariant = calc.get_descriptors(at, invariants_only=True) + desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) desc_invariant_single_layer = calc.get_descriptors( - at, invariants_only=True, num_layers=1, enable_cueq=True + at, invariants_only=True, num_layers=1 ) desc_invariant_single_layer_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, num_layers=1, enable_cueq=True - ) - desc = calc.get_descriptors(at, invariants_only=False, enable_cueq=True) - desc_single_layer = calc.get_descriptors( - at, invariants_only=False, num_layers=1, enable_cueq=True - ) - desc_rotated = calc.get_descriptors( - at_rotated, invariants_only=False, enable_cueq=True + at_rotated, invariants_only=True, num_layers=1 ) + desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) + desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) desc_rotated_single_layer = calc.get_descriptors( - at_rotated, invariants_only=False, num_layers=1, enable_cueq=True + at_rotated, invariants_only=False, num_layers=1 ) assert desc_invariant.shape[0] == 3 @@ -566,7 +625,13 @@ def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model): np.testing.assert_allclose( desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 ) - np.testing.assert_allclose(desc, desc_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) + assert not np.allclose(desc, desc_rotated, atol=1e-6) def test_mace_mp(capsys: pytest.CaptureFixture):
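
---

Notes on the series

1) Patch 01 extends default_dtype to accept dtype names as strings, but the
restore after the yield is not wrapped in try/finally, so an exception raised
inside the with-body would leave the temporary default installed. Below is a
minimal standalone sketch of the hardened pattern, assuming stock torch only;
the string table stands in for the repo's existing set_default_dtype helper
and is not part of any patch:

    from contextlib import contextmanager

    import torch

    # Stand-in for mace.tools.torch_tools.set_default_dtype; only the two
    # dtypes the benchmark parametrizes over.
    _STR_TO_DTYPE = {"float32": torch.float32, "float64": torch.float64}

    @contextmanager
    def default_dtype(dtype):
        """Temporarily switch torch's default dtype, restoring it on exit."""
        init = torch.get_default_dtype()
        torch.set_default_dtype(
            _STR_TO_DTYPE[dtype] if isinstance(dtype, str) else dtype
        )
        try:
            yield
        finally:
            # Runs even when the with-body raises, unlike the bare yield
            # in the patch.
            torch.set_default_dtype(init)

    with default_dtype("float64"):
        assert torch.get_default_dtype() is torch.float64
    # Back to the process default (float32 unless something changed it).
    assert torch.get_default_dtype() is torch.float32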
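
2) The unit conversions in read_bench_results (patch 02) are easy to misread:
pytest-benchmark reports "median" in seconds per round and "ops" in rounds per
second, and the benchmark treats one model call as one MD step at a 1 fs
timestep. A worked example with an illustrative step time:

    median = 0.025                  # seconds per step (illustrative)
    ops = 1.0 / median              # steps/second, as pytest-benchmark defines it

    steps_per_day = ops * 86400     # matches df["ops"] * 86400
    ns_per_day = 86400e-6 / median  # 86400 fs of MD per day = 0.0864 ns
    assert abs(ns_per_day - 0.086400 / median) < 1e-12  # the patch's constant

so a 25 ms step corresponds to roughly 3.5 ns/day of simulated dynamics.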
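
3) Patches 04-06 make the preprocess argument parser config-file driven when
configargparse is available and fall back to plain argparse otherwise; the
test in patch 05 drives it with a YAML file. A reduced sketch of the pattern:
only --num_process is kept from the real option surface, and parser.add is
configargparse's alias for add_argument, as used in the patch itself:

    import argparse

    try:
        import configargparse

        parser = configargparse.ArgumentParser(
            config_file_parser_class=configargparse.YAMLConfigFileParser,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add(
            "--config",
            type=str,
            is_config_file=True,
            help="config file to aggregate options",
        )
    except ImportError:
        # No configargparse installed: same options, but no --config support.
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

    parser.add_argument("--num_process", type=int, default=1)
    args = parser.parse_args(["--num_process", "2"])
    assert args.num_process == 2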
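
4) The shape assertions in test_calculator_descriptor_cueq (patches 11-12)
follow from dimension counting for hidden_irreps="16x0e+16x1o", assuming the
usual MACE convention that the final interaction layer carries scalar (l = 0)
features only; that convention is also why the rotation checks compare just
the first 16 columns:

    channels, l_max = 16, 1                    # 16x0e + 16x1o
    full_layer = channels * (l_max + 1) ** 2   # 16 scalars + 48 l=1 components
    scalar_layer = channels                    # final layer: 16 scalars only

    assert full_layer == 16 * 4                # desc_single_layer.shape[1]
    assert 2 * scalar_layer == 32              # desc_invariant.shape[1], 2 layers
    assert full_layer + scalar_layer == 80     # desc.shape[1]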