diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 3dc5b609..70baab5c 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -558,15 +558,27 @@ def run(args: argparse.Namespace) -> None:
         ],
         lr=args.lr,
         amsgrad=args.amsgrad,
+        betas=(args.beta, 0.999),
     )
 
     optimizer: torch.optim.Optimizer
     if args.optimizer == "adamw":
         optimizer = torch.optim.AdamW(**param_options)
+    elif args.optimizer == "schedulefree":
+        try:
+            from schedulefree import adamw_schedulefree
+        except ImportError as exc:
+            raise ImportError(
+                "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
+            ) from exc
+        _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
+        optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
     else:
         optimizer = torch.optim.Adam(**param_options)
 
-    logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
+    logger = tools.MetricsLogger(
+        directory=args.results_dir, tag=tag + "_train"
+    )  # pylint: disable=E1123
 
     lr_scheduler = LRScheduler(optimizer, args)
 
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index 96f1e185..18ce98f5 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -446,7 +446,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         help="Optimizer for parameter optimization",
         type=str,
         default="adam",
-        choices=["adam", "adamw"],
+        choices=["adam", "adamw", "schedulefree"],
+    )
+    parser.add_argument(
+        "--beta",
+        help="Beta1 parameter of the optimizer (beta2 is fixed at 0.999)",
+        type=float,
+        default=0.9,
     )
     parser.add_argument("--batch_size", help="batch size", type=int, default=10)
     parser.add_argument(
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index 2d75a7b1..0f8d72cf 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -167,10 +167,18 @@ def radial_to_transform(radial):
         "num_interactions": model.num_interactions.item(),
         "num_elements": len(model.atomic_numbers),
         "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
-        "MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)),
-        "gate": model.readouts[-1]  # pylint: disable=protected-access
-        .non_linearity._modules["acts"][0]
-        .f,
+        "MLP_irreps": (
+            o3.Irreps(str(model.readouts[-1].hidden_irreps))
+            if model.num_interactions.item() > 1
+            else 1
+        ),
+        "gate": (
+            model.readouts[-1]  # pylint: disable=protected-access
+            .non_linearity._modules["acts"][0]
+            .f
+            if model.num_interactions.item() > 1
+            else None
+        ),
         "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
         "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
         "atomic_numbers": model.atomic_numbers,
@@ -373,6 +381,9 @@ def custom_key(key):
 class LRScheduler:
     def __init__(self, optimizer, args) -> None:
         self.scheduler = args.scheduler
+        self._optimizer_type = (
+            args.optimizer
+        )  # ScheduleFree does not need an LR scheduler, but the checkpoint handler does.
         if args.scheduler == "ExponentialLR":
             self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                 optimizer=optimizer, gamma=args.lr_scheduler_gamma
@@ -387,6 +398,8 @@ def __init__(self, optimizer, args) -> None:
             raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")
 
     def step(self, metrics=None, epoch=None):  # pylint: disable=E1123
+        if self._optimizer_type == "schedulefree":
+            return  # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary
         if self.scheduler == "ExponentialLR":
             self.lr_scheduler.step(epoch=epoch)
         elif self.scheduler == "ReduceLROnPlateau":
diff --git a/mace/tools/train.py b/mace/tools/train.py
index a3f73ff9..7ebf3ce1 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -175,7 +175,8 @@ def train(
         # Train
         if distributed:
            train_sampler.set_epoch(epoch)
-
+        if "ScheduleFree" in type(optimizer).__name__:
+            optimizer.train()
         train_one_epoch(
             model=model,
             loss_fn=loss_fn,
@@ -201,6 +202,8 @@
         param_context = (
             ema.average_parameters() if ema is not None else nullcontext()
         )
+        if "ScheduleFree" in type(optimizer).__name__:
+            optimizer.eval()
         with param_context:
             valid_loss, eval_metrics = evaluate(
                 model=model_to_evaluate,
diff --git a/setup.cfg b/setup.cfg
index 1f9ca90e..81e0b661 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,3 +50,4 @@ dev =
     pre-commit
     pytest
     pylint
+schedulefree = schedulefree
\ No newline at end of file
diff --git a/tests/test_schedulefree.py b/tests/test_schedulefree.py
new file mode 100644
index 00000000..0d93b829
--- /dev/null
+++ b/tests/test_schedulefree.py
@@ -0,0 +1,127 @@
+from unittest.mock import MagicMock
+import numpy as np
+import pytest
+import torch
+import torch.nn.functional as F
+from e3nn import o3
+
+from mace import data, modules, tools
+from mace.tools import torch_geometric, scripts_utils
+import tempfile
+
+try:
+    import schedulefree
+except ImportError:
+    pytest.skip(
+        "Skipping schedulefree tests due to ImportError", allow_module_level=True
+    )
+
+torch.set_default_dtype(torch.float64)
+
+table = tools.AtomicNumberTable([6])
+atomic_energies = np.array([1.0], dtype=float)
+cutoff = 5.0
+
+
+def create_mace(device: str, seed: int = 1702):
+    torch_geometric.seed_everything(seed)
+
+    model_config = {
+        "r_max": cutoff,
+        "num_bessel": 8,
+        "num_polynomial_cutoff": 6,
+        "max_ell": 3,
+        "interaction_cls": modules.interaction_classes[
+            "RealAgnosticResidualInteractionBlock"
+        ],
+        "interaction_cls_first": modules.interaction_classes[
+            "RealAgnosticResidualInteractionBlock"
+        ],
+        "num_interactions": 2,
+        "num_elements": 1,
+        "hidden_irreps": o3.Irreps("8x0e + 8x1o"),
+        "MLP_irreps": o3.Irreps("16x0e"),
+        "gate": F.silu,
+        "atomic_energies": atomic_energies,
+        "avg_num_neighbors": 8,
+        "atomic_numbers": table.zs,
+        "correlation": 3,
+        "radial_type": "bessel",
+    }
+    model = modules.MACE(**model_config)
+    return model.to(device)
+
+
+def create_batch(device: str):
+    from ase import build
+
+    size = 2
+    atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
+    atoms_list = [atoms.repeat((size, size, size))]
+    print("Number of atoms", len(atoms_list[0]))
+
+    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
+    data_loader = torch_geometric.dataloader.DataLoader(
+        dataset=[
+            data.AtomicData.from_config(config, z_table=table, cutoff=cutoff)
+            for config in configs
+        ],
+        batch_size=1,
+        shuffle=False,
+        drop_last=False,
+    )
+    batch = next(iter(data_loader))
+    batch = batch.to(device)
+    batch = batch.to_dict()
+    return batch
+
+
+def do_optimization_step(
+    model,
+    optimizer,
+    device,
+):
+    batch = create_batch(device)
+    model.train()
+    optimizer.train()
+    optimizer.zero_grad()
+    output = model(batch, training=True, compute_force=False)
+    loss = output["energy"].mean()
+    loss.backward()
+    optimizer.step()
+    model.eval()
+    optimizer.eval()
+
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_can_load_checkpoint(device):
+    model = create_mace(device)
+    optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters())
+    args = MagicMock()
+    args.optimizer = "schedulefree"
+    args.scheduler = "ExponentialLR"
+    args.lr_scheduler_gamma = 0.9
+    lr_scheduler = scripts_utils.LRScheduler(optimizer, args)
+    with tempfile.TemporaryDirectory() as d:
+        checkpoint_handler = tools.CheckpointHandler(
+            directory=d, keep=False, tag="schedulefree"
+        )
+        for _ in range(10):
+            do_optimization_step(model, optimizer, device)
+        batch = create_batch(device)
+        output = model(batch)
+        energy = output["energy"].detach().cpu().numpy()
+
+        state = tools.CheckpointState(
+            model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
+        )
+        checkpoint_handler.save(state, epochs=0, keep_last=False)
+        checkpoint_handler.load_latest(
+            state=tools.CheckpointState(model, optimizer, lr_scheduler),
+            swa=False,
+        )
+        batch = create_batch(device)
+        output = model(batch)
+        new_energy = output["energy"].detach().cpu().numpy()
+        assert np.allclose(energy, new_energy, atol=1e-9)
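
Note (not part of the patch): the train.py changes above rely on ScheduleFree's mode-switching contract, i.e. the optimizer must be in train mode around optimizer.step() and switched to eval mode before validation or checkpointing, because in eval mode ScheduleFree exposes the averaged weights rather than the point used for gradient evaluation. The following is a minimal, self-contained sketch of that contract. The adamw_schedulefree.AdamWScheduleFree class and its train()/eval() calls are the ones used in the diff; the toy Linear model, random data, and learning rate are placeholders, not MACE code.

# Illustrative sketch only: ScheduleFree train/eval mode switching.
import torch
import torch.nn.functional as F
from schedulefree import adamw_schedulefree

model = torch.nn.Linear(4, 1)  # toy stand-in for a MACE model
optimizer = adamw_schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-2)

x, y = torch.randn(64, 4), torch.randn(64, 1)  # toy regression data

for _ in range(100):
    model.train()
    optimizer.train()  # gradients must be computed in train mode
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # no lr_scheduler.step(): schedule-free needs none

model.eval()
optimizer.eval()  # switch to the averaged weights before evaluation or saving
with torch.no_grad():
    print("eval loss:", F.mse_loss(model(x), y).item())

No LR scheduler is stepped here, which mirrors LRScheduler.step() returning early when the optimizer type is "schedulefree".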