Commit

initial
argparse support schedulefree

support specifying betas

fix single layer logging

Add optional dependency and a loading test
RokasEl committed Jun 20, 2024
1 parent 7842e99 commit 8ca8f09
Showing 6 changed files with 169 additions and 7 deletions.
14 changes: 13 additions & 1 deletion mace/cli/run_train.py
@@ -558,15 +558,27 @@ def run(args: argparse.Namespace) -> None:
        ],
        lr=args.lr,
        amsgrad=args.amsgrad,
        betas=(args.beta, 0.999),
    )

    optimizer: torch.optim.Optimizer
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(**param_options)
    elif args.optimizer == "schedulefree":
        try:
            from schedulefree import adamw_schedulefree
        except ImportError as exc:
            raise ImportError(
                "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
            ) from exc
        _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
        optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
    else:
        optimizer = torch.optim.Adam(**param_options)

    logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
    logger = tools.MetricsLogger(
        directory=args.results_dir, tag=tag + "_train"
    )  # pylint: disable=E1123

    lr_scheduler = LRScheduler(optimizer, args)

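For orientation, here is a hedged, standalone sketch (not part of the diff) of the optimizer construction that the hunk above adds: `amsgrad` is accepted by Adam/AdamW but not by AdamWScheduleFree, so it is filtered out before the schedule-free optimizer is built. The toy model and hyperparameter values below are placeholders.

import torch
from schedulefree import adamw_schedulefree

model = torch.nn.Linear(8, 1)  # toy stand-in for the MACE model
param_options = dict(
    params=model.parameters(),
    lr=0.01,
    amsgrad=True,  # understood by Adam/AdamW only
    betas=(0.9, 0.999),  # the first entry corresponds to the new --beta flag
)

# AdamWScheduleFree has no `amsgrad` argument, hence the filtering in the diff above.
sf_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**sf_options)
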
8 changes: 7 additions & 1 deletion mace/tools/arg_parser.py
@@ -446,7 +446,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
        help="Optimizer for parameter optimization",
        type=str,
        default="adam",
        choices=["adam", "adamw"],
        choices=["adam", "adamw", "schedulefree"],
    )
    parser.add_argument(
        "--beta",
        help="Beta parameter for the optimizer",
        type=float,
        default=0.9,
    )
    parser.add_argument("--batch_size", help="batch size", type=int, default=10)
    parser.add_argument(
21 changes: 17 additions & 4 deletions mace/tools/scripts_utils.py
@@ -167,10 +167,18 @@ def radial_to_transform(radial):
        "num_interactions": model.num_interactions.item(),
        "num_elements": len(model.atomic_numbers),
        "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
        "MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)),
        "gate": model.readouts[-1]  # pylint: disable=protected-access
        .non_linearity._modules["acts"][0]
        .f,
        "MLP_irreps": (
            o3.Irreps(str(model.readouts[-1].hidden_irreps))
            if model.num_interactions.item() > 1
            else 1
        ),
        "gate": (
            model.readouts[-1]  # pylint: disable=protected-access
            .non_linearity._modules["acts"][0]
            .f
            if model.num_interactions.item() > 1
            else None
        ),
        "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
        "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
        "atomic_numbers": model.atomic_numbers,
@@ -373,6 +381,9 @@ def custom_key(key):
class LRScheduler:
    def __init__(self, optimizer, args) -> None:
        self.scheduler = args.scheduler
        self._optimizer_type = (
            args.optimizer
        )  # Schedulefree does not need a scheduler, but the checkpoint handler does.
        if args.scheduler == "ExponentialLR":
            self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer=optimizer, gamma=args.lr_scheduler_gamma
@@ -387,6 +398,8 @@ def __init__(self, optimizer, args) -> None:
            raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")

    def step(self, metrics=None, epoch=None):  # pylint: disable=E1123
        if self._optimizer_type == "schedulefree":
            return  # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary
        if self.scheduler == "ExponentialLR":
            self.lr_scheduler.step(epoch=epoch)
        elif self.scheduler == "ReduceLROnPlateau":
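To make the bypass in LRScheduler.step concrete, here is a small hedged sketch (mirroring the new test at the end of this commit rather than the training script): when args.optimizer is "schedulefree", the wrapper still constructs the underlying torch scheduler, but step() returns immediately, so the learning rate is left entirely to the schedule-free optimizer. The SimpleNamespace stands in for the parsed CLI arguments.

from types import SimpleNamespace

import torch
from schedulefree import adamw_schedulefree

from mace.tools import scripts_utils

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = adamw_schedulefree.AdamWScheduleFree(params, lr=1e-2)

args = SimpleNamespace(
    optimizer="schedulefree",
    scheduler="ExponentialLR",
    lr_scheduler_gamma=0.9,
)
lr_scheduler = scripts_utils.LRScheduler(optimizer, args)
lr_scheduler.step(epoch=0)  # no-op: the LR is not decayed for schedule-free runs
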
5 changes: 4 additions & 1 deletion mace/tools/train.py
@@ -175,7 +175,8 @@ def train(
        # Train
        if distributed:
            train_sampler.set_epoch(epoch)

        if "ScheduleFree" in type(optimizer).__name__:
            optimizer.train()
        train_one_epoch(
            model=model,
            loss_fn=loss_fn,
@@ -201,6 +202,8 @@
            param_context = (
                ema.average_parameters() if ema is not None else nullcontext()
            )
            if "ScheduleFree" in type(optimizer).__name__:
                optimizer.eval()
            with param_context:
                valid_loss, eval_metrics = evaluate(
                    model=model_to_evaluate,
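The two hunks above encode the usage contract of schedule-free optimizers: they keep an averaged and a non-averaged set of weights, and .train()/.eval() must be called to switch the model's parameters between them around training steps and evaluation (or checkpointing). Below is a minimal, self-contained sketch of that pattern with a toy model and random data, not the MACE training loop itself.

import torch
import torch.nn.functional as F
from schedulefree import adamw_schedulefree

model = torch.nn.Linear(4, 1)
optimizer = adamw_schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-2)
x, y = torch.randn(32, 4), torch.randn(32, 1)

for epoch in range(3):
    model.train()
    optimizer.train()  # training mode: parameters sit at the optimizer's training point
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    model.eval()
    optimizer.eval()  # evaluation mode: parameters are switched to the averaged weights
    with torch.no_grad():
        val_loss = F.mse_loss(model(x), y)
        print(f"epoch {epoch}: train loss {loss.item():.4f}, val loss {val_loss.item():.4f}")
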
1 change: 1 addition & 0 deletions setup.cfg
@@ -50,3 +50,4 @@ dev =
    pre-commit
    pytest
    pylint
schedulefree = schedulefree
127 changes: 127 additions & 0 deletions tests/test_schedulefree.py
@@ -0,0 +1,127 @@
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from e3nn import o3

from mace import data, modules, tools
from mace.tools import torch_geometric, scripts_utils
import tempdir

try:
    import schedulefree
except ImportError:
    pytest.skip(
        "Skipping schedulefree tests due to ImportError", allow_module_level=True
    )

torch.set_default_dtype(torch.float64)

table = tools.AtomicNumberTable([6])
atomic_energies = np.array([1.0], dtype=float)
cutoff = 5.0


def create_mace(device: str, seed: int = 1702):
    torch_geometric.seed_everything(seed)

    model_config = {
        "r_max": cutoff,
        "num_bessel": 8,
        "num_polynomial_cutoff": 6,
        "max_ell": 3,
        "interaction_cls": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "interaction_cls_first": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "num_interactions": 2,
        "num_elements": 1,
        "hidden_irreps": o3.Irreps("8x0e + 8x1o"),
        "MLP_irreps": o3.Irreps("16x0e"),
        "gate": F.silu,
        "atomic_energies": atomic_energies,
        "avg_num_neighbors": 8,
        "atomic_numbers": table.zs,
        "correlation": 3,
        "radial_type": "bessel",
    }
    model = modules.MACE(**model_config)
    return model.to(device)


def create_batch(device: str):
    from ase import build

    size = 2
    atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
    atoms_list = [atoms.repeat((size, size, size))]
    print("Number of atoms", len(atoms_list[0]))

    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(config, z_table=table, cutoff=cutoff)
            for config in configs
        ],
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    batch = batch.to(device)
    batch = batch.to_dict()
    return batch


def do_optimization_step(
    model,
    optimizer,
    device,
):
    batch = create_batch(device)
    model.train()
    optimizer.train()
    optimizer.zero_grad()
    output = model(batch, training=True, compute_force=False)
    loss = output["energy"].mean()
    loss.backward()
    optimizer.step()
    model.eval()
    optimizer.eval()


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_can_load_checkpoint(device):
    model = create_mace(device)
    optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters())
    args = MagicMock()
    args.optimizer = "schedulefree"
    args.scheduler = "ExponentialLR"
    args.lr_scheduler_gamma = 0.9
    lr_scheduler = scripts_utils.LRScheduler(optimizer, args)
    with tempdir.TempDir() as d:
        checkpoint_handler = tools.CheckpointHandler(
            directory=d, keep=False, tag="schedulefree"
        )
        for _ in range(10):
            do_optimization_step(model, optimizer, device)
        batch = create_batch(device)
        output = model(batch)
        energy = output["energy"].detach().cpu().numpy()

        state = tools.CheckpointState(
            model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
        checkpoint_handler.save(state, epochs=0, keep_last=False)
        checkpoint_handler.load_latest(
            state=tools.CheckpointState(model, optimizer, lr_scheduler),
            swa=False,
        )
        batch = create_batch(device)
        output = model(batch)
        new_energy = output["energy"].detach().cpu().numpy()
        assert np.allclose(energy, new_energy, atol=1e-9)
