diff --git a/.gitignore b/.gitignore
index 3817d9f3..296776e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,5 @@ dist/
 *.xyz
 /checkpoints
 *.model
+
+.benchmarks
diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py
index 51f27ab0..3ebddbed 100644
--- a/mace/calculators/foundations_models.py
+++ b/mace/calculators/foundations_models.py
@@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
 
     checkpoint_url = (
         urls.get(model, urls["medium"])
-        if model in (None, "small", "medium", "large")
+        if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
         else model
     )
 
@@ -106,7 +106,7 @@ def mace_mp(
         MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
     """
     try:
-        if model in (None, "small", "medium", "large") or str(model).startswith(
+        if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
             "https:"
         ):
             model_path = download_mace_mp_checkpoint(model)
diff --git a/mace/data/utils.py b/mace/data/utils.py
index bb8e5448..59b868ed 100644
--- a/mace/data/utils.py
+++ b/mace/data/utils.py
@@ -265,7 +265,6 @@ def load_from_xyz(
     atoms_without_iso_atoms = []
 
     for idx, atoms in enumerate(atoms_list):
-        atoms.info[head_key] = head_name
         isolated_atom_config = (
             len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom"
         )
@@ -288,6 +287,9 @@
     if not keep_isolated_atoms:
         atoms_list = atoms_without_iso_atoms
 
+    for atoms in atoms_list:
+        atoms.info[head_key] = head_name
+
     configs = config_from_atoms_list(
         atoms_list,
         config_type_weights=config_type_weights,
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index cbce368a..e4e90a10 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
             "--config",
             type=str,
             is_config_file=True,
-            help="config file to agregate options",
+            help="config file to aggregate options",
         )
     except ImportError:
         parser = argparse.ArgumentParser(
@@ -727,9 +727,24 @@
 
 
 def build_preprocess_arg_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
+    try:
+        import configargparse
+
+        parser = configargparse.ArgumentParser(
+            config_file_parser_class=configargparse.YAMLConfigFileParser,
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        )
+        parser.add(
+            "--config",
+            type=str,
+            is_config_file=True,
+            help="config file to aggregate options",
+        )
+    except ImportError:
+        parser = argparse.ArgumentParser(
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        )
+
     parser.add_argument(
         "--train_file",
         help="Training set h5 file",
diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py
index e42a74f8..31e837df 100644
--- a/mace/tools/torch_tools.py
+++ b/mace/tools/torch_tools.py
@@ -6,7 +6,7 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Dict
+from typing import Dict, Union
 
 import numpy as np
 import torch
@@ -129,13 +129,18 @@ def init_wandb(project: str, entity: str, name: str, config: dict, directory: st
 
 
 @contextmanager
-def default_dtype(dtype: torch.dtype):
+def default_dtype(dtype: Union[torch.dtype, str]):
     """Context manager for configuring the default_dtype used by torch
 
     Args:
-        dtype (torch.dtype): the default dtype to use within this context manager
+        dtype (torch.dtype|str): the default dtype to use within this context manager
     """
     init = torch.get_default_dtype()
-    torch.set_default_dtype(dtype)
+    if isinstance(dtype, str):
+        set_default_dtype(dtype)
+    else:
+        torch.set_default_dtype(dtype)
+
     yield
+
     torch.set_default_dtype(init)
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py
new file mode 100644
index 00000000..78b04ccd
--- /dev/null
+++ b/tests/test_benchmark.py
@@ -0,0 +1,122 @@
+import os
+from typing import Optional
+
+import pandas as pd
+import json
+import pytest
+import torch
+from ase import build
+
+from mace import data
+from mace.calculators.foundations_models import mace_mp
+from mace.tools import AtomicNumberTable, torch_geometric, torch_tools
+
+
+def is_mace_full_bench():
+    return os.environ.get("MACE_FULL_BENCH", "0") == "1"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
+@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
+@pytest.mark.parametrize("size", (3, 5, 7, 9))
+@pytest.mark.parametrize("dtype", ["float32", "float64"])
+@pytest.mark.parametrize("compile_mode", [None, "default"])
+def test_inference(
+    benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda"
+):
+    if not is_mace_full_bench() and compile_mode is not None:
+        pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")
+
+    with torch_tools.default_dtype(dtype):
+        model = load_mace_mp_medium(dtype, compile_mode, device)
+        batch = create_batch(size, model, device)
+        log_bench_info(benchmark, dtype, compile_mode, batch)
+
+        def func():
+            torch.cuda.synchronize()
+            model(batch, training=compile_mode is not None, compute_force=True)
+
+        torch.cuda.empty_cache()
+        benchmark(func)
+
+
+def load_mace_mp_medium(dtype, compile_mode, device):
+    calc = mace_mp(
+        model="medium",
+        default_dtype=dtype,
+        device=device,
+        compile_mode=compile_mode,
+        fullgraph=False,
+    )
+    model = calc.models[0].to(device)
+    return model
+
+
+def create_batch(size: int, model: torch.nn.Module, device: str) -> dict:
+    cutoff = model.r_max.item()
+    z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
+    atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
+    atoms = atoms.repeat((size, size, size))
+    config = data.config_from_atoms(atoms)
+    dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
+    data_loader = torch_geometric.dataloader.DataLoader(
+        dataset=dataset,
+        batch_size=1,
+        shuffle=False,
+        drop_last=False,
+    )
+    batch = next(iter(data_loader))
+    batch.to(device)
+    return batch.to_dict()
+
+
+def log_bench_info(benchmark, dtype, compile_mode, batch):
+    benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0])
+    benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1])
+    benchmark.extra_info["dtype"] = dtype
+    benchmark.extra_info["is_compiled"] = compile_mode is not None
+    benchmark.extra_info["device_name"] = torch.cuda.get_device_name()
+
+
+def read_bench_results(files: list[str]) -> pd.DataFrame:
+    def read(file):
+        with open(file, "r") as f:
+            data = json.load(f)
+
+        records = []
+        for bench in data["benchmarks"]:
+            record = {**bench["extra_info"], **bench["stats"]}
+            records.append(record)
+
+        df = pd.DataFrame(records)
+        df["ns/day (1 fs/step)"] = 0.086400 / df["median"]
+        df["Steps per day"] = df["ops"] * 86400
+        columns = [
+            "num_atoms",
+            "num_edges",
+            "dtype",
+            "is_compiled",
+            "device_name",
+            "median",
+            "Steps per day",
+            "ns/day (1 fs/step)",
+        ]
+        return df[columns]
+
+    return pd.concat([read(f) for f in files])
+
+
+if __name__ == "__main__":
+    # Print to stdout a csv of the benchmark metrics
+    import subprocess
+
+    result = subprocess.run(
+        ["pytest-benchmark", "list"], capture_output=True, text=True
+    )
+
+    if result.returncode != 0:
+        raise RuntimeError(f"Command failed with return code {result.returncode}")
+
+    files = result.stdout.strip().split("\n")
+    df = read_bench_results(files)
+    print(df.to_csv(index=False))
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index e0258bd4..1f3068ba 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -6,6 +6,7 @@
 import ase.io
 import numpy as np
 import pytest
+import yaml
 from ase.atoms import Atoms
 
 pytest_mace_dir = Path(__file__).parent.parent
@@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs):
     np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8)
 
     print("All checks passed successfully!")
+
+
+def test_preprocess_config(tmp_path, sample_configs):
+    ase.io.write(tmp_path / "sample.xyz", sample_configs)
+
+    preprocess_params = {
+        "train_file": str(tmp_path / "sample.xyz"),
+        "r_max": 5.0,
+        "config_type_weights": "{'Default':1.0}",
+        "num_process": 2,
+        "valid_fraction": 0.1,
+        "h5_prefix": str(tmp_path / "preprocessed_"),
+        "compute_statistics": None,
+        "seed": 42,
+        "energy_key": "REF_energy",
+        "forces_key": "REF_forces",
+        "stress_key": "REF_stress",
+    }
+    filename = tmp_path / "config.yaml"
+    with open(filename, "w", encoding="utf-8") as file:
+        yaml.dump(preprocess_params, file)
+
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(preprocess_data)
+        + " "
+        + "--config"
+        + " "
+        + str(filename)
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
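
Usage note (not part of the patch): a minimal sketch of the user-facing change to the dtype handling. The extended default_dtype context manager now accepts either a torch.dtype or a dtype string, as exercised by tests/test_benchmark.py; the string path is assumed here to go through the existing set_default_dtype helper in mace/tools/torch_tools.py, which maps "float32"/"float64" to the corresponding torch dtypes.

import torch
from mace.tools import torch_tools

# String form (new in this patch); the previous default dtype is restored on exit.
with torch_tools.default_dtype("float64"):
    assert torch.get_default_dtype() == torch.float64

# torch.dtype form continues to work as before.
with torch_tools.default_dtype(torch.float32):
    assert torch.get_default_dtype() == torch.float32

# The GPU inference benchmarks in tests/test_benchmark.py are driven by
# pytest-benchmark; the long compile-mode runs are opt-in, e.g.:
#   MACE_FULL_BENCH=1 pytest tests/test_benchmark.py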