diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py
index 9341ca8b..3f647906 100644
--- a/mace/cli/create_lammps_model.py
+++ b/mace/cli/create_lammps_model.py
@@ -1,6 +1,8 @@
 import argparse
+
 import torch
 from e3nn.util import jit
+
 from mace.calculators import LAMMPS_MACE
 
 
@@ -42,12 +44,11 @@ def select_head(model):
 
     if selected.isdigit() and 1 <= int(selected) <= len(heads):
         return heads[int(selected) - 1]
-    elif selected == "":
+    if selected == "":
         print("No head selected. Proceeding without specifying a head.")
         return None
-    else:
-        print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
-        return heads[-1]
+    print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
+    return heads[-1]
 
 
 def main():
diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 11cfbeb0..17b39667 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -41,12 +41,16 @@
     dict_to_array,
     extract_config_mace_model,
     get_atomic_energies,
+    get_avg_num_neighbors,
     get_config_type_weights,
     get_dataset_from_xyz,
     get_files_with_suffix,
     get_loss_fn,
+    get_optimizer,
+    get_params_options,
     get_swa,
     print_git_commit,
+    setup_wandb,
 )
 from mace.tools.slurm_distributed import DistributedEnvironment
 from mace.tools.utils import AtomicNumberTable
@@ -194,6 +198,7 @@ def run(args: argparse.Namespace) -> None:
                 head_config.config_type_weights
             )
             collections, atomic_energies_dict = get_dataset_from_xyz(
+                work_dir=args.work_dir,
                 train_path=head_config.train_file,
                 valid_path=head_config.valid_file,
                 valid_fraction=head_config.valid_fraction,
@@ -219,10 +224,10 @@ def run(args: argparse.Namespace) -> None:
 
     if all(head_config.train_file.endswith(".xyz") for head_config in head_configs):
         size_collections_train = sum(
-            [len(head_config.collections.train) for head_config in head_configs]
+            len(head_config.collections.train) for head_config in head_configs
         )
         size_collections_valid = sum(
-            [len(head_config.collections.valid) for head_config in head_configs]
+            len(head_config.collections.valid) for head_config in head_configs
         )
         if size_collections_train < args.batch_size:
             logging.error(
@@ -257,6 +262,7 @@ def run(args: argparse.Namespace) -> None:
         else:
             heads = list(dict.fromkeys(["pt_head"] + heads))
             collections, atomic_energies_dict = get_dataset_from_xyz(
+                work_dir=args.work_dir,
                 train_path=args.pt_train_file,
                 valid_path=args.pt_valid_file,
                 valid_fraction=args.valid_fraction,
@@ -367,8 +373,9 @@ def run(args: argparse.Namespace) -> None:
                 ].item()
                 for z in z_table.zs
             }
+            atomic_energies_dict_pt = atomic_energies_dict["pt_head"]
             logging.info(
-                f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}"
+                f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}"
             )
 
     if args.model == "AtomicDipolesMACE":
@@ -394,9 +401,14 @@ def run(args: argparse.Namespace) -> None:
         # [atomic_energies_dict[z] for z in z_table.zs]
         # )
         atomic_energies = dict_to_array(atomic_energies_dict, heads)
-        logging.info(
-            f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}"
+        result = "\n".join(
+            [
+                f"Atomic Energies used (z: eV) for head {head_config.head_name}: "
+                + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}"
+                for head_config in head_configs
+            ]
         )
+        logging.info(result)
 
     valid_sets = {head: [] for head in heads}
     train_sets = {head: [] for head in heads}
@@ -488,30 +500,7 @@ def run(args: argparse.Namespace) -> None:
     )
 
     loss_fn = get_loss_fn(args, dipole_only, compute_dipole)
-
-    if all(head_config.compute_avg_num_neighbors for head_config in head_configs):
-        logging.info("Computing average number of neighbors")
-        avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
-        if args.distributed:
-            num_graphs = torch.tensor(len(train_loader.dataset)).to(device)
-            num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device)
-            torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM)
-            torch.distributed.all_reduce(
-                num_neighbors, op=torch.distributed.ReduceOp.SUM
-            )
-            args.avg_num_neighbors = (num_neighbors / num_graphs).item()
-        else:
-            args.avg_num_neighbors = avg_num_neighbors
-    else:
-        assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration"
-        args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None)
-
-    if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100:
-        logging.warning(
-            f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}"
-        )
-    else:
-        logging.info(f"Average number of neighbors: {args.avg_num_neighbors}")
+    args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)
 
     # Selecting outputs
     compute_virials = False
@@ -745,61 +734,9 @@ def run(args: argparse.Namespace) -> None:
     logging.info(loss_fn)
 
     # Optimizer
-    decay_interactions = {}
-    no_decay_interactions = {}
-    for name, param in model.interactions.named_parameters():
-        if "linear.weight" in name or "skip_tp_full.weight" in name:
-            decay_interactions[name] = param
-        else:
-            no_decay_interactions[name] = param
-
-    param_options = dict(
-        params=[
-            {
-                "name": "embedding",
-                "params": model.node_embedding.parameters(),
-                "weight_decay": 0.0,
-            },
-            {
-                "name": "interactions_decay",
-                "params": list(decay_interactions.values()),
-                "weight_decay": args.weight_decay,
-            },
-            {
-                "name": "interactions_no_decay",
-                "params": list(no_decay_interactions.values()),
-                "weight_decay": 0.0,
-            },
-            {
-                "name": "products",
-                "params": model.products.parameters(),
-                "weight_decay": args.weight_decay,
-            },
-            {
-                "name": "readouts",
-                "params": model.readouts.parameters(),
-                "weight_decay": 0.0,
-            },
-        ],
-        lr=args.lr,
-        amsgrad=args.amsgrad,
-        betas=(args.beta, 0.999),
-    )
-
+    param_options = get_params_options(args, model)
     optimizer: torch.optim.Optimizer
-    if args.optimizer == "adamw":
-        optimizer = torch.optim.AdamW(**param_options)
-    elif args.optimizer == "schedulefree":
-        try:
-            from schedulefree import adamw_schedulefree
-        except ImportError as exc:
-            raise ImportError(
-                "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
-            ) from exc
-        _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
-        optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
-    else:
-        optimizer = torch.optim.Adam(**param_options)
+    optimizer = get_optimizer(args, param_options)
     if args.device == "xpu":
         logging.info("Optimzing model and optimzier for XPU")
         model, optimizer = ipex.optimize(model, optimizer=optimizer)
@@ -846,27 +783,7 @@ def run(args: argparse.Namespace) -> None:
             group["lr"] = args.lr
 
     if args.wandb:
-        logging.info("Using Weights and Biases for logging")
-        import wandb
-
-        wandb_config = {}
-        args_dict = vars(args)
-
-        for key, value in args_dict.items():
-            if isinstance(value, np.ndarray):
-                args_dict[key] = value.tolist()
-
-        args_dict_json = json.dumps(args_dict)
-        for key in args.wandb_log_hypers:
-            wandb_config[key] = args_dict[key]
-        tools.init_wandb(
-            project=args.wandb_project,
-            entity=args.wandb_entity,
-            name=args.wandb_name,
-            config=wandb_config,
-            directory=args.wandb_dir,
-        )
-        wandb.run.summary["params"] = args_dict_json
+        setup_wandb(args)
 
     if args.distributed:
         distributed_model = DDP(model, device_ids=[local_rank])
diff --git a/mace/modules/loss.py b/mace/modules/loss.py
index 2d6522d2..91462d2c 100644
--- a/mace/modules/loss.py
+++ b/mace/modules/loss.py
@@ -273,8 +273,6 @@ def __init__(
 
     def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
         num_atoms = ref.ptr[1:] - ref.ptr[:-1]
-        configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, ]
-        configs_stress_weight = ref.stress_weight.view(-1, 1, 1)  # [n_graphs, ]
         return (
             self.energy_weight
             * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms)
diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py
index ce4172fc..8ad80243 100644
--- a/mace/tools/__init__.py
+++ b/mace/tools/__init__.py
@@ -28,7 +28,6 @@
     compute_rel_rmse,
     compute_rmse,
     get_atomic_number_table_from_zs,
-    get_optimizer,
     get_tag,
     setup_logger,
 )
@@ -46,7 +45,6 @@
     "setup_logger",
     "get_tag",
     "count_parameters",
-    "get_optimizer",
     "MetricsLogger",
     "get_atomic_number_table_from_zs",
     "train",
diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py
index 1e190da2..8892fa6d 100644
--- a/mace/tools/multihead_tools.py
+++ b/mace/tools/multihead_tools.py
@@ -162,6 +162,7 @@ def assemble_mp_data(
     }
     select_samples(dict_to_namespace(args_samples))
     collections_mp, _ = get_dataset_from_xyz(
+        work_dir=args.work_dir,
        train_path=f"mp_finetuning-{tag}.xyz",
         valid_path=None,
         valid_fraction=args.valid_fraction,
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index 5a19439c..025b3453 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -20,6 +20,7 @@
 from torch.optim.swa_utils import SWALR, AveragedModel
 
 from mace import data, modules
+from mace import tools
 from mace.tools import evaluate
 from mace.tools.train import SWAContainer
 
@@ -349,6 +350,38 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict:
     return atomic_energies_dict
 
 
+def get_avg_num_neighbors(head_configs, args, train_loader, device):
+    if all(head_config.compute_avg_num_neighbors for head_config in head_configs):
+        logging.info("Computing average number of neighbors")
+        avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
+        if args.distributed:
+            num_graphs = torch.tensor(len(train_loader.dataset)).to(device)
+            num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device)
+            torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(
+                num_neighbors, op=torch.distributed.ReduceOp.SUM
+            )
+            avg_num_neighbors_out = (num_neighbors / num_graphs).item()
+        else:
+            avg_num_neighbors_out = avg_num_neighbors
+    else:
+        assert any(
+            head_config.avg_num_neighbors is not None for head_config in head_configs
+        ), "Average number of neighbors must be provided in the configuration"
+        avg_num_neighbors_out = max(
+            head_config.avg_num_neighbors
+            for head_config in head_configs
+            if head_config.avg_num_neighbors is not None
+        )
+    if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100:
+        logging.warning(
+            f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}"
+        )
+    else:
+        logging.info(f"Average number of neighbors: {avg_num_neighbors_out}")
+    return avg_num_neighbors_out
+
+
 def get_loss_fn(
     args: argparse.Namespace,
     dipole_only: bool,
@@ -482,6 +515,95 @@ def get_swa(
     return swa, swas
 
 
+def get_params_options(
+    args: argparse.Namespace, model: torch.nn.Module
+) -> Dict[str, Any]:
+    decay_interactions = {}
+    no_decay_interactions = {}
+    for name, param in model.interactions.named_parameters():
+        if "linear.weight" in name or "skip_tp_full.weight" in name:
+            decay_interactions[name] = param
+        else:
+            no_decay_interactions[name] = param
+
+    param_options = dict(
+        params=[
+            {
+                "name": "embedding",
+                "params": model.node_embedding.parameters(),
+                "weight_decay": 0.0,
+            },
+            {
+                "name": "interactions_decay",
+                "params": list(decay_interactions.values()),
+                "weight_decay": args.weight_decay,
+            },
+            {
+                "name": "interactions_no_decay",
+                "params": list(no_decay_interactions.values()),
+                "weight_decay": 0.0,
+            },
+            {
+                "name": "products",
+                "params": model.products.parameters(),
+                "weight_decay": args.weight_decay,
+            },
+            {
+                "name": "readouts",
+                "params": model.readouts.parameters(),
+                "weight_decay": 0.0,
+            },
+        ],
+        lr=args.lr,
+        amsgrad=args.amsgrad,
+        betas=(args.beta, 0.999),
+    )
+    return param_options
+
+
+def get_optimizer(
+    args: argparse.Namespace, param_options: Dict[str, Any]
+) -> torch.optim.Optimizer:
+    if args.optimizer == "adamw":
+        optimizer = torch.optim.AdamW(**param_options)
+    elif args.optimizer == "schedulefree":
+        try:
+            from schedulefree import adamw_schedulefree
+        except ImportError as exc:
+            raise ImportError(
+                "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
+            ) from exc
+        _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
+        optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
+    else:
+        optimizer = torch.optim.Adam(**param_options)
+    return optimizer
+
+
+def setup_wandb(args: argparse.Namespace):
+    logging.info("Using Weights and Biases for logging")
+    import wandb
+
+    wandb_config = {}
+    args_dict = vars(args)
+
+    for key, value in args_dict.items():
+        if isinstance(value, np.ndarray):
+            args_dict[key] = value.tolist()
+
+    args_dict_json = json.dumps(args_dict)
+    for key in args.wandb_log_hypers:
+        wandb_config[key] = args_dict[key]
+    tools.init_wandb(
+        project=args.wandb_project,
+        entity=args.wandb_entity,
+        name=args.wandb_name,
+        config=wandb_config,
+        directory=args.wandb_dir,
+    )
+    wandb.run.summary["params"] = args_dict_json
+
+
 def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]:
     return [
         os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix)
diff --git a/mace/tools/utils.py b/mace/tools/utils.py
index 762d9880..0d7aa41e 100644
--- a/mace/tools/utils.py
+++ b/mace/tools/utils.py
@@ -121,26 +121,6 @@ def atomic_numbers_to_indices(
     return to_index_fn(atomic_numbers)
 
 
-def get_optimizer(
-    name: str,
-    amsgrad: bool,
-    learning_rate: float,
-    weight_decay: float,
-    parameters: Iterable[torch.Tensor],
-) -> torch.optim.Optimizer:
-    if name == "adam":
-        return torch.optim.Adam(
-            parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay
-        )
-
-    if name == "adamw":
-        return torch.optim.AdamW(
-            parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay
-        )
-
-    raise RuntimeError(f"Unknown optimizer '{name}'")
-
-
 class UniversalEncoder(json.JSONEncoder):
     def default(self, o):
         if isinstance(o, np.integer):
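Usage note (illustrative sketch, not part of the patch): after this refactor the optimizer setup that used to live inline in run_train.py is importable from mace.tools.scripts_utils. The sketch below assumes a model exposing node_embedding, interactions, products and readouts sub-modules, which is what the moved get_params_options requires; _Block and TinyModel are hypothetical stand-ins and the hyperparameter values are arbitrary.

import argparse

import torch

from mace.tools.scripts_utils import get_optimizer, get_params_options


class _Block(torch.nn.Module):
    # Hypothetical stand-in for an interaction block; get_params_options routes
    # parameters whose name contains "linear.weight" into the weight-decay group.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in exposing the sub-modules get_params_options expects.
    def __init__(self):
        super().__init__()
        self.node_embedding = torch.nn.Linear(4, 8)
        self.interactions = torch.nn.ModuleList([_Block()])
        self.products = torch.nn.ModuleList([torch.nn.Linear(8, 8)])
        self.readouts = torch.nn.ModuleList([torch.nn.Linear(8, 1)])


# Arbitrary illustrative values; run_train.py fills these fields from its CLI parser.
args = argparse.Namespace(
    lr=0.01, weight_decay=5e-7, amsgrad=True, beta=0.9, optimizer="adamw"
)
param_options = get_params_options(args, TinyModel())  # param groups + lr/amsgrad/betas
optimizer = get_optimizer(args, param_options)  # "adamw", "schedulefree", or Adam fallback
print(type(optimizer).__name__)  # AdamW

The same args namespace drives the other relocated helpers in run_train.py, get_avg_num_neighbors(head_configs, args, train_loader, device) and setup_wandb(args), as shown in the hunks above.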