Merge pull request #727 from ACEsuit/develop
change learning rate for multihead ft
ilyes319 authored Dec 5, 2024
2 parents 1ea5c55 + 2b30ec9 commit af4b739
Showing 11 changed files with 333 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -33,3 +33,5 @@ dist/
*.xyz
/checkpoints
*.model

.benchmarks
4 changes: 2 additions & 2 deletions mace/calculators/foundations_models.py
@@ -42,7 +42,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
else model
)

@@ -106,7 +106,7 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
if model in (None, "small", "medium", "large") or str(model).startswith(
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
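With the extended name list, the "-0b" and "-0b2" checkpoint variants now resolve to their download URLs instead of being treated as local paths. A minimal usage sketch (model name taken from the list above; the device and dtype values are only illustrative):

from mace.calculators.foundations_models import mace_mp

# "medium-0b2" now maps to a hosted checkpoint URL rather than a file path
calc = mace_mp(model="medium-0b2", default_dtype="float64", device="cpu")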
3 changes: 2 additions & 1 deletion mace/calculators/mace.py
@@ -14,6 +14,7 @@
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
@@ -406,7 +407,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
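Rebuilding the irreps with o3.Irreps means the descriptor bookkeeping no longer depends on how the product block's linear layer stores its irreps_out attribute internally; anything whose string form is a valid irreps specification works. A small illustration of the attributes used just below (the irreps string here is only an example):

from e3nn import o3

irreps_out = o3.Irreps("128x0e + 128x1o")
l_max = irreps_out.lmax                                        # 1
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2    # 512 // 4 = 128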
10 changes: 10 additions & 0 deletions mace/cli/run_train.py
@@ -160,6 +160,16 @@ def run(args: argparse.Namespace) -> None:
args.E0s != "average"
), "average atomic energies cannot be used for multiheads finetuning"
# check that the foundation model has a single head, if not, use the first head
if not args.force_mh_ft_lr:
logging.info(
"Multihead finetuning mode, setting learning rate to 0.001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True."
)
args.lr = 0.001
args.ema = True
args.ema_decay = 0.999
logging.info(
"Using multiheads finetuning mode, setting learning rate to 0.001 and EMA to True"
)
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
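In multihead fine-tuning mode the run now overrides the optimizer settings unless --force_mh_ft_lr=True is passed. A condensed sketch of the effect (the Namespace below stands in for the parsed arguments, with illustrative starting values):

from argparse import Namespace

args = Namespace(force_mh_ft_lr=False, lr=0.01, ema=False, ema_decay=0.99)
if not args.force_mh_ft_lr:        # pass --force_mh_ft_lr=True to keep your own lr/EMA settings
    args.lr, args.ema, args.ema_decay = 0.001, True, 0.999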
4 changes: 3 additions & 1 deletion mace/data/utils.py
@@ -265,7 +265,6 @@ def load_from_xyz(
atoms_without_iso_atoms = []

for idx, atoms in enumerate(atoms_list):
atoms.info[head_key] = head_name
isolated_atom_config = (
len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom"
)
@@ -288,6 +287,9 @@
if not keep_isolated_atoms:
atoms_list = atoms_without_iso_atoms

for atoms in atoms_list:
atoms.info[head_key] = head_name

configs = config_from_atoms_list(
atoms_list,
config_type_weights=config_type_weights,
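The head label is now attached after the isolated-atom filtering, so it is applied once to exactly the list that goes on to config_from_atoms_list. A simplified sketch of the new ordering (key and value are illustrative, and the length filter stands in for the IsolatedAtom handling above):

from ase import Atoms

atoms_list = [Atoms("H2O"), Atoms("H")]
atoms_list = [a for a in atoms_list if len(a) > 1]   # filtering happens first
for atoms in atoms_list:                             # then every surviving config gets its head label
    atoms.info["head"] = "default"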
29 changes: 25 additions & 4 deletions mace/tools/arg_parser.py
@@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--config",
type=str,
is_config_file=True,
help="config file to agregate options",
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
@@ -379,6 +379,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=int,
default=1000,
)
parser.add_argument(
"--force_mh_ft_lr",
help="Force the multiheaded fine-tuning to use arg_parser lr",
type=str2bool,
default=False,
)
parser.add_argument(
"--subselect_pt",
help="Method to subselect the configurations of the pretraining set",
@@ -721,9 +727,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
try:
import configargparse

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
type=str,
is_config_file=True,
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
"--train_file",
help="Training set h5 file",
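build_preprocess_arg_parser now mirrors the training parser: when configargparse is installed, preprocessing options can be read from a YAML --config file, with a plain argparse fallback otherwise. A minimal sketch (the YAML file name is illustrative and would need to supply the remaining required options):

from mace.tools.arg_parser import build_preprocess_arg_parser

parser = build_preprocess_arg_parser()
args = parser.parse_args(["--config", "preprocess.yaml"])   # options read from YAML as well as the CLI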
13 changes: 9 additions & 4 deletions mace/tools/torch_tools.py
@@ -6,7 +6,7 @@

import logging
from contextlib import contextmanager
from typing import Dict
from typing import Dict, Union

import numpy as np
import torch
@@ -129,13 +129,18 @@ def init_wandb(project: str, entity: str, name: str, config: dict, directory: st


@contextmanager
def default_dtype(dtype: torch.dtype):
def default_dtype(dtype: Union[torch.dtype, str]):
"""Context manager for configuring the default_dtype used by torch
Args:
dtype (torch.dtype): the default dtype to use within this context manager
dtype (torch.dtype|str): the default dtype to use within this context manager
"""
init = torch.get_default_dtype()
torch.set_default_dtype(dtype)
if isinstance(dtype, str):
set_default_dtype(dtype)
else:
torch.set_default_dtype(dtype)

yield

torch.set_default_dtype(init)
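default_dtype now accepts the string spellings used elsewhere in the codebase (e.g. the --default_dtype CLI option) in addition to a torch.dtype. A minimal sketch of the extended behaviour:

import torch
from mace.tools import torch_tools

with torch_tools.default_dtype("float64"):   # string form is routed through set_default_dtype
    x = torch.zeros(3)                       # tensors created inside default to float64
assert x.dtype == torch.float64              # the context restores the previous default on exit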
122 changes: 122 additions & 0 deletions tests/test_benchmark.py
@@ -0,0 +1,122 @@
import os
from typing import Optional

import pandas as pd
import json
import pytest
import torch
from ase import build

from mace import data
from mace.calculators.foundations_models import mace_mp
from mace.tools import AtomicNumberTable, torch_geometric, torch_tools


def is_mace_full_bench():
return os.environ.get("MACE_FULL_BENCH", "0") == "1"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
@pytest.mark.parametrize("size", (3, 5, 7, 9))
@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("compile_mode", [None, "default"])
def test_inference(
benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda"
):
if not is_mace_full_bench() and compile_mode is not None:
pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")

with torch_tools.default_dtype(dtype):
model = load_mace_mp_medium(dtype, compile_mode, device)
batch = create_batch(size, model, device)
log_bench_info(benchmark, dtype, compile_mode, batch)

def func():
torch.cuda.synchronize()
model(batch, training=compile_mode is not None, compute_force=True)

torch.cuda.empty_cache()
benchmark(func)


def load_mace_mp_medium(dtype, compile_mode, device):
calc = mace_mp(
model="medium",
default_dtype=dtype,
device=device,
compile_mode=compile_mode,
fullgraph=False,
)
model = calc.models[0].to(device)
return model


def create_batch(size: int, model: torch.nn.Module, device: str) -> dict:
cutoff = model.r_max.item()
z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
atoms = atoms.repeat((size, size, size))
config = data.config_from_atoms(atoms)
dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=dataset,
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader))
batch.to(device)
return batch.to_dict()


def log_bench_info(benchmark, dtype, compile_mode, batch):
benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0])
benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1])
benchmark.extra_info["dtype"] = dtype
benchmark.extra_info["is_compiled"] = compile_mode is not None
benchmark.extra_info["device_name"] = torch.cuda.get_device_name()


def read_bench_results(files: list[str]) -> pd.DataFrame:
def read(file):
with open(file, "r") as f:
data = json.load(f)

records = []
for bench in data["benchmarks"]:
record = {**bench["extra_info"], **bench["stats"]}
records.append(record)

df = pd.DataFrame(records)
df["ns/day (1 fs/step)"] = 0.086400 / df["median"]
df["Steps per day"] = df["ops"] * 86400
columns = [
"num_atoms",
"num_edges",
"dtype",
"is_compiled",
"device_name",
"median",
"Steps per day",
"ns/day (1 fs/step)",
]
return df[columns]

return pd.concat([read(f) for f in files])


if __name__ == "__main__":
# Print to stdout a csv of the benchmark metrics
import subprocess

result = subprocess.run(
["pytest-benchmark", "list"], capture_output=True, text=True
)

if result.returncode != 0:
raise RuntimeError(f"Command failed with return code {result.returncode}")

files = result.stdout.strip().split("\n")
df = read_bench_results(files)
print(df.to_csv(index=False))
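The new benchmark is gated: it is skipped entirely without CUDA, and the torch.compile variants only run when MACE_FULL_BENCH=1 is set. A hedged sketch of one way to invoke it (assumes pytest and pytest-benchmark are installed and the path is relative to the repository root):

import os
import subprocess

env = {**os.environ, "MACE_FULL_BENCH": "1"}   # also run the compile_mode="default" cases
subprocess.run(["pytest", "tests/test_benchmark.py", "--benchmark-only"], check=True, env=env)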