Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/ACEsuit/mace into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Dec 5, 2024
2 parents 44caf89 + 3f53730 commit 2b30ec9
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ dist/
*.xyz
/checkpoints
*.model

.benchmarks
4 changes: 2 additions & 2 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
else model
)

Expand Down Expand Up @@ -106,7 +106,7 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
if model in (None, "small", "medium", "large") or str(model).startswith(
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
Expand Down
4 changes: 3 additions & 1 deletion mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ def load_from_xyz(
atoms_without_iso_atoms = []

for idx, atoms in enumerate(atoms_list):
atoms.info[head_key] = head_name
isolated_atom_config = (
len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom"
)
Expand All @@ -288,6 +287,9 @@ def load_from_xyz(
if not keep_isolated_atoms:
atoms_list = atoms_without_iso_atoms

for atoms in atoms_list:
atoms.info[head_key] = head_name

configs = config_from_atoms_list(
atoms_list,
config_type_weights=config_type_weights,
Expand Down
23 changes: 19 additions & 4 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--config",
type=str,
is_config_file=True,
help="config file to agregate options",
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -727,9 +727,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
try:
import configargparse

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
type=str,
is_config_file=True,
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
13 changes: 9 additions & 4 deletions mace/tools/torch_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
from contextlib import contextmanager
from typing import Dict
from typing import Dict, Union

import numpy as np
import torch
Expand Down Expand Up @@ -129,13 +129,18 @@ def init_wandb(project: str, entity: str, name: str, config: dict, directory: st


@contextmanager
def default_dtype(dtype: torch.dtype):
def default_dtype(dtype: Union[torch.dtype, str]):
"""Context manager for configuring the default_dtype used by torch
Args:
dtype (torch.dtype): the default dtype to use within this context manager
dtype (torch.dtype|str): the default dtype to use within this context manager
"""
init = torch.get_default_dtype()
torch.set_default_dtype(dtype)
if isinstance(dtype, str):
set_default_dtype(dtype)
else:
torch.set_default_dtype(dtype)

yield

torch.set_default_dtype(init)
122 changes: 122 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os
from typing import Optional

import pandas as pd
import json
import pytest
import torch
from ase import build

from mace import data
from mace.calculators.foundations_models import mace_mp
from mace.tools import AtomicNumberTable, torch_geometric, torch_tools


def is_mace_full_bench():
return os.environ.get("MACE_FULL_BENCH", "0") == "1"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
@pytest.mark.parametrize("size", (3, 5, 7, 9))
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("compile_mode", [None, "default"])
def test_inference(
benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda"
):
if not is_mace_full_bench() and compile_mode is not None:
pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")

with torch_tools.default_dtype(dtype):
model = load_mace_mp_medium(dtype, compile_mode, device)
batch = create_batch(size, model, device)
log_bench_info(benchmark, dtype, compile_mode, batch)

def func():
torch.cuda.synchronize()
model(batch, training=compile_mode is not None, compute_force=True)

torch.cuda.empty_cache()
benchmark(func)


def load_mace_mp_medium(dtype, compile_mode, device):
calc = mace_mp(
model="medium",
default_dtype=dtype,
device=device,
compile_mode=compile_mode,
fullgraph=False,
)
model = calc.models[0].to(device)
return model


def create_batch(size: int, model: torch.nn.Module, device: str) -> dict:
cutoff = model.r_max.item()
z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
atoms = atoms.repeat((size, size, size))
config = data.config_from_atoms(atoms)
dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=dataset,
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader))
batch.to(device)
return batch.to_dict()


def log_bench_info(benchmark, dtype, compile_mode, batch):
benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0])
benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1])
benchmark.extra_info["dtype"] = dtype
benchmark.extra_info["is_compiled"] = compile_mode is not None
benchmark.extra_info["device_name"] = torch.cuda.get_device_name()


def read_bench_results(files: list[str]) -> pd.DataFrame:
def read(file):
with open(file, "r") as f:
data = json.load(f)

records = []
for bench in data["benchmarks"]:
record = {**bench["extra_info"], **bench["stats"]}
records.append(record)

df = pd.DataFrame(records)
df["ns/day (1 fs/step)"] = 0.086400 / df["median"]
df["Steps per day"] = df["ops"] * 86400
columns = [
"num_atoms",
"num_edges",
"dtype",
"is_compiled",
"device_name",
"median",
"Steps per day",
"ns/day (1 fs/step)",
]
return df[columns]

return pd.concat([read(f) for f in files])


if __name__ == "__main__":
# Print to stdout a csv of the benchmark metrics
import subprocess

result = subprocess.run(
["pytest-benchmark", "list"], capture_output=True, text=True
)

if result.returncode != 0:
raise RuntimeError(f"Command failed with return code {result.returncode}")

files = result.stdout.strip().split("\n")
df = read_bench_results(files)
print(df.to_csv(index=False))
40 changes: 40 additions & 0 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ase.io
import numpy as np
import pytest
import yaml
from ase.atoms import Atoms

pytest_mace_dir = Path(__file__).parent.parent
Expand Down Expand Up @@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs):
np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8)

print("All checks passed successfully!")


def test_preprocess_config(tmp_path, sample_configs):
ase.io.write(tmp_path / "sample.xyz", sample_configs)

preprocess_params = {
"train_file": str(tmp_path / "sample.xyz"),
"r_max": 5.0,
"config_type_weights": "{'Default':1.0}",
"num_process": 2,
"valid_fraction": 0.1,
"h5_prefix": str(tmp_path / "preprocessed_"),
"compute_statistics": None,
"seed": 42,
"energy_key": "REF_energy",
"forces_key": "REF_forces",
"stress_key": "REF_stress",
}
filename = tmp_path / "config.yaml"
with open(filename, "w", encoding="utf-8") as file:
yaml.dump(preprocess_params, file)

run_env = os.environ.copy()
sys.path.insert(0, str(Path(__file__).parent.parent))
run_env["PYTHONPATH"] = ":".join(sys.path)
print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])

cmd = (
sys.executable
+ " "
+ str(preprocess_data)
+ " "
+ "--config"
+ " "
+ str(filename)
)

p = subprocess.run(cmd.split(), env=run_env, check=True)
assert p.returncode == 0

0 comments on commit 2b30ec9

Please sign in to comment.