diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 575d9b3f..68432a5a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -256,7 +256,7 @@ def main() -> None: head_args.valid_set = data.dataset_from_sharded_hdf5( head_args.valid_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank ) - + # subset train ratio if "train_ratio" in head_args.keys(): ratio = head_args.train_ratio @@ -428,11 +428,13 @@ def main() -> None: huber_delta=args.huber_delta, ) elif args.loss == "universal": + head_stress_mask = torch.Tensor([float('mp' in k) for k in args.heads.keys()]).to(device=device) # TODO: make it general loss_fn = modules.UniversalLoss( energy_weight=args.energy_weight, forces_weight=args.forces_weight, stress_weight=args.stress_weight, huber_delta=args.huber_delta, + head_stress_mask=head_stress_mask, ) elif args.loss == "dipole": assert ( @@ -458,7 +460,9 @@ def main() -> None: if args.loss in ("stress", "virials", "huber", "universal"): compute_virials = True args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" + # args.error_table = "PerAtomRMSEstressvirials" + logging.info(f"Overwriting the error table due to the loss setting -> {args.loss} loss") + args.error_table = "PerAtomRMSE+EMAEstressvirials" output_args = { "energy": compute_energy, diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 1087bacb..338f5e9c 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -57,6 +57,41 @@ def forward( return self.linear(x) # [n_nodes, 1] +class GroupavgReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps=o3.Irreps("0e"), + layered: int=2, # choice from [1, 2] + resolution: int=2, # choice from [0, 1, 2] + ): + super().__init__() + self.irreps_in = irreps_in + self.non_linearity = gate + input_size = irreps_in.dim + output_size = irrep_out.dim + hidden_size = 128 + self.MLP = torch.nn.Sequential( + torch.nn.Linear(input_size, hidden_size), + torch.nn.BatchNorm1d(hidden_size), + torch.nn.SiLU(), + torch.nn.Linear(hidden_size, output_size) + ) + self.layered = layered + self.resolution = resolution + self.register_buffer("SO3_grid", + o3.quaternion_to_matrix( + torch.load(f"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_{layered}_{resolution}.pt").to(torch.get_default_dtype()))) + + def forward(self, x: torch.Tensor, heads: Optional[torch.Tensor] = None): + rand_D = o3.rand_matrix(device=x.device) + gs = self.SO3_grid @ rand_D # [n_grid, 3, 3] + Ds = self.irreps_in.D_from_matrix(gs) # [n_grid, D, D] + + xs = torch.einsum("nd,rjd->nrj", x, Ds) # [n_graphs, D], [n_grid, D, D] -> [n_graphs, n_grid, D] + outs = self.MLP(xs.view(-1, xs.size(-1))) # [n_graphs * n_grid, out_dim] + out = torch.mean(outs.view(*xs.shape[:-1], -1), dim=1, keepdim=False) # average over the grid -> [n_graphs, out_dim] + return out + @simplify_if_compile @compile_mode("script") class NonLinearReadoutBlock(torch.nn.Module): @@ -80,7 +115,7 @@ def forward( ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None: - x = mask_head(x, heads, self.num_heads) + x = mask_head(x, heads, self.num_heads) # decorrelate the per-head MLPs return self.linear_2(x) # [n_nodes, len(heads)]
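Note on the new GroupavgReadoutBlock above: instead of constraining the readout to be equivariant, it approximates an invariant readout by averaging an unconstrained MLP over a grid of rotations (the precomputed SO3_grid_*.pt quaternion files, re-oriented by a random rotation each forward pass). A minimal sketch of the idea, not part of this patch, assuming only torch and e3nn and using a random rotation sample in place of the precomputed grid:

import torch
from e3nn import o3

irreps_in = o3.Irreps("3x0e + 1x1o + 1x2e")
mlp = torch.nn.Sequential(
    torch.nn.Linear(irreps_in.dim, 64), torch.nn.SiLU(), torch.nn.Linear(64, 1)
)
grid = o3.rand_matrix(512)  # fixed rotation sample, stands in for the SO3_grid buffer

def groupavg_readout(x: torch.Tensor) -> torch.Tensor:
    # Rotate the features by every grid element (Wigner-D action of irreps_in),
    # run the MLP on each rotated copy, and average over the grid.
    Ds = irreps_in.D_from_matrix(grid)        # [n_grid, D, D]
    xs = torch.einsum("nd,rjd->nrj", x, Ds)   # [n, n_grid, D]
    outs = mlp(xs.reshape(-1, xs.size(-1)))   # [n * n_grid, 1]
    return outs.view(x.size(0), len(grid), -1).mean(dim=1)

x = irreps_in.randn(8, -1)
x_rot = x @ irreps_in.D_from_matrix(o3.rand_matrix())
# Both calls are Monte-Carlo estimates of the same Haar average, so they agree
# up to a sampling error that shrinks as the grid grows.
print((groupavg_readout(x) - groupavg_readout(x_rot)).abs().max())

The block in blocks.py additionally multiplies the stored grid by a fresh random rotation on every forward pass, which randomizes the residual grid error from one pass to the next.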
diff --git a/mace/modules/loss.py b/mace/modules/loss.py index b3421ef5..dbf67ff4 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -253,7 +253,8 @@ def __repr__(self): class UniversalLoss(torch.nn.Module): def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01, + head_stress_mask=None ) -> None: super().__init__() self.huber_delta = huber_delta @@ -270,16 +271,27 @@ def __init__( "stress_weight", torch.tensor(stress_weight, dtype=torch.get_default_dtype()), ) + self.head_stress_mask = head_stress_mask def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] - return ( - self.energy_weight - * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) - + self.forces_weight - * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) - ) + if self.head_stress_mask is None: + return ( + self.energy_weight + * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + + self.forces_weight + * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + ) + else: + stress_mask = self.head_stress_mask[ref.head].view(-1, 1, 1) + return ( + self.energy_weight + * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + + self.forces_weight + * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + + self.stress_weight * self.huber_loss(ref["stress"] * stress_mask, pred["stress"] * stress_mask) + ) def __repr__(self): return ( diff --git a/mace/modules/models.py b/mace/modules/models.py index b6018892..36a3260d 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -385,6 +385,7 @@ def forward( # Interactions node_es_list = [pair_node_energy] node_feats_list = [] + # import ipdb; ipdb.set_trace() for interaction, product, readout in zip( self.interactions, self.products, self.readouts ): diff --git a/mace/modules/test_grpavg_readout.ipynb b/mace/modules/test_grpavg_readout.ipynb new file mode 100644 index 00000000..1d9c793d --- /dev/null +++ b/mace/modules/test_grpavg_readout.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 188, + "metadata": {}, + "outputs": [], + "source": [ + "from abc import abstractmethod\n", + "from typing import Callable, List, Optional, Tuple, Union\n", + "\n", + "import numpy as np\n", + "from torch.nn.functional import silu\n", + "from e3nn import nn, o3\n", + "from e3nn.util.jit import compile_mode\n", + "import torch\n", + "# Set the default floating-point type to float64\n", + "torch.set_default_dtype(torch.float64)\n", + "\n", + "class GroupavgReadoutBlock(torch.nn.Module):\n", + "\n", + " def __init__(self, irreps_in: o3.Irreps,\n", + " gate: Optional[Callable],\n", + " irrep_out: o3.Irreps=o3.Irreps(\"0e\"),\n", + " ):\n", + " super().__init__()\n", + " self.irreps_in = irreps_in\n", + " self.non_linearity = gate\n", + " input_size = irreps_in.dim\n", + " output_size = irrep_out.dim\n", + " hidden_size = 128\n", + " self.MLP = torch.nn.Sequential(\n", + " torch.nn.Linear(input_size, hidden_size),\n", + " torch.nn.BatchNorm1d(hidden_size),\n", + " torch.nn.SiLU(),\n", + " torch.nn.Linear(hidden_size, output_size)\n", + " )\n", + " self.register_buffer(\"SO3_grid_1_0\", \n", + " o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_0.pt\").to(torch.get_default_dtype())))\n", + " self.register_buffer(\"SO3_grid_1_1\", \n", + "
o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_1.pt\").to(torch.get_default_dtype())))\n", + " self.register_buffer(\"SO3_grid_1_2\", \n", + " o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_2.pt\").to(torch.get_default_dtype())))\n", + " self.register_buffer(\"SO3_grid_2_0\", \n", + " o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_0.pt\").to(torch.get_default_dtype())))\n", + " self.register_buffer(\"SO3_grid_2_1\", \n", + " o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_1.pt\").to(torch.get_default_dtype())))\n", + " self.register_buffer(\"SO3_grid_2_2\", \n", + " o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_2.pt\").to(torch.get_default_dtype())))\n", + "\n", + "\n", + " def forward(self, x: torch.Tensor, heads: Optional[torch.Tensor] = None):\n", + " rand_D = o3.rand_matrix(device=x.device)\n", + " gs = self.SO3_grid_1_2 @ rand_D # [72, 3, 3]\n", + " Ds = self.irreps_in.D_from_matrix(gs) # [72, D, D]\n", + "\n", + " xs = torch.einsum(\"nd,rjd->nrj\", x, Ds) # [n_graphs, D], [72, D, D] -> [n_graphs, 72, D]\n", + " print(xs.shape)\n", + " outs = self.MLP(xs.view(-1, xs.size(-1))) # [n_graph, 72, 1]\n", + " out = torch.mean(outs.view(*xs.shape[:-1], -1), dim=1, keepdim=False)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "metadata": {}, + "outputs": [], + "source": [ + "irreps_in = o3.Irreps(\"3x0e+1x1o+1x2e\")\n", + "n_graph = 32\n", + "readout = GroupavgReadoutBlock(irreps_in=irreps_in, gate=torch.nn.SiLU)" + ] + }, + { + "cell_type": "code", + "execution_count": 205, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([32, 4608, 11])\n", + "torch.Size([32, 4608, 11])\n", + "tensor(0.0003, grad_fn=)\n" + ] + } + ], + "source": [ + "x = irreps_in.randn(n_graph, -1)\n", + "\n", + "out = readout(x)\n", + "\n", + "rot_x = x @ irreps_in.D_from_matrix(o3.rand_matrix())\n", + "\n", + "rot_out = readout(rot_x)\n", + "# print(x - rot_x)\n", + "print((rot_out - out).abs().mean())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mace/modules/utils.py b/mace/modules/utils.py index ebdabae5..1ac47c09 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -267,8 +267,8 @@ def compute_mean_rms_energy_forces( forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } head = torch.cat(head_list, dim=0) # [total_n_graphs] head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] - - mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0)) rms = to_numpy( torch.sqrt( scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) diff --git a/mace/tools/train.py b/mace/tools/train.py index 8823b248..c307d2f6 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -72,6 +72,21 @@ def valid_err_log( logging.info( f"head: {valid_loader_name}, Epoch 
{epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3" ) + elif ( + log_errors == "PerAtomRMSE+EMAEstressvirials" and eval_metrics["rmse_stress"] is not None + ): + error_e_rmse = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f_rmse = eval_metrics["rmse_f"] * 1e3 + error_stress_rmse = eval_metrics["rmse_stress"] * 1e3 + error_e_mae = eval_metrics["mae_e_per_atom"] * 1e3 + error_f_mae = eval_metrics["mae_f"] * 1e3 + error_stress_mae = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, \t RMSE_E_per_atom={error_e_rmse:.1f} meV, RMSE_F={error_f_rmse:.1f} meV / A, RMSE_stress={error_stress_rmse:.1f} meV / A^3" + ) + logging.info( + f" \t MAE_E_per_atom={error_e_mae:.1f} meV, MAE_F={error_f_mae:.1f} meV / A, MAE_stress={error_stress_mae:.1f} meV / A^3" + ) elif ( log_errors == "PerAtomRMSEstressvirials" and eval_metrics["rmse_virials_per_atom"] is not None diff --git a/multihead_config/jz_mp_config_r6.0.yaml b/multihead_config/jz_mp_config_r6.0.yaml new file mode 100644 index 00000000..49652653 --- /dev/null +++ b/multihead_config/jz_mp_config_r6.0.yaml @@ -0,0 +1,15 @@ +avg_num_neighbor_head: mp_pbe +device: cuda +multi_processed_test: True +heads: + mp_pbe: + train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj + valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj + E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json + config_type_weights: + Default: 1.0 + avg_num_neighbors: 61.9649349317854 + mean: 0.1634233391135065 + std: 0.7735790334431056 +#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice +#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json diff --git a/multihead_config/jz_oc_mp_config_r6.0.yaml b/multihead_config/jz_oc_mp_config_r6.0.yaml new file mode 100644 index 00000000..f07f98cb --- /dev/null +++ b/multihead_config/jz_oc_mp_config_r6.0.yaml @@ -0,0 +1,25 @@ +avg_num_neighbor_head: mp_pbe +device: cuda +multi_processed_test: True +heads: + spice_wB97M: + train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/spice + valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/spice + E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json + config_type_weights: + Default: 1.0 + avg_num_neighbors: 22.86736849018836 + mean: -4.406405198254238 + std: 1.0737544472166278 + + mp_pbe: + train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj + valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj + E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json + config_type_weights: + Default: 1.0 + avg_num_neighbors: 61.9649349317854 + mean: 0.1634233391135065 + std: 0.7735790334431056 +#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice +#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json diff --git a/multihead_config/jz_spice_mp_config.yaml b/multihead_config/jz_spice_mp_config.yaml index c1cfc25c..309ba4d9 100644 --- a/multihead_config/jz_spice_mp_config.yaml +++ b/multihead_config/jz_spice_mp_config.yaml @@ -21,5 +21,8 @@ heads: avg_num_neighbors: 35.985167534166 mean: -4.48071865 std: 0.77357903 -#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice -#statistics_file: 
/lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json + + # test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice + # statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json + # mean and std do not depend on r_max + # not computed online; run the compute-statistics script with the same yaml. \ No newline at end of file diff --git a/multihead_config/jz_spice_mp_config_r6.0.yaml b/multihead_config/jz_spice_mp_config_r6.0.yaml new file mode 100644 index 00000000..f07f98cb --- /dev/null +++ b/multihead_config/jz_spice_mp_config_r6.0.yaml @@ -0,0 +1,25 @@ +avg_num_neighbor_head: mp_pbe +device: cuda +multi_processed_test: True +heads: + spice_wB97M: + train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/spice + valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/spice + E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json + config_type_weights: + Default: 1.0 + avg_num_neighbors: 22.86736849018836 + mean: -4.406405198254238 + std: 1.0737544472166278 + + mp_pbe: + train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj + valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj + E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json + config_type_weights: + Default: 1.0 + avg_num_neighbors: 61.9649349317854 + mean: 0.1634233391135065 + std: 0.7735790334431056 +#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice +#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json diff --git a/n1gpu4_train.slurm b/n1gpu4_train.slurm index 47ce1204..c8241da7 100644 --- a/n1gpu4_train.slurm +++ b/n1gpu4_train.slurm @@ -3,18 +3,24 @@ #SBATCH --account=gax@h100 # account #SBATCH -C h100 # target H100 nodes # Here, reservation of 3x24=72 CPUs (for 3 tasks) and 3 GPUs (1 GPU per task) on one node only: -#SBATCH --nodes=1 # number of node +#SBATCH --nodes=8 # number of nodes #SBATCH --ntasks-per-node=4 # number of MPI tasks per node (here = number of GPUs per node) #SBATCH --gres=gpu:4 # number of GPUs per node (max 4 for H100 nodes) # Knowing that here we only reserve one GPU per task (i.e. 1/4 of GPUs), # the ideal is to reserve 1/4 of CPUs for each task: -#SBATCH --cpus-per-task=16 # number of CPUs per task (here 1/4 of the node) +#SBATCH --cpus-per-task=8 # number of CPUs per task (here 1/4 of the node) # /!\ Caution, "multithread" in Slurm vocabulary refers to hyperthreading.
#SBATCH --hint=nomultithread # hyperthreading deactivated #SBATCH --time=20:00:00 # maximum execution time requested (HH:MM:SS) -#SBATCH --output=n1gpu4-mace-train-%j.out # name of output file -#SBATCH --error=n1gpu4-mace-train-%j.out # name of error file (here, in common with the output file) - +#SBATCH --output=n1gpu4-mace-train-%A_%a.out # name of output file +#SBATCH --error=n1gpu4-mace-train-%A_%a.out # name of error file (here, in common with the output file) +#SBATCH --array=0-3%1 # Array index range + +# Access arguments +bs=16 +lr=0.005 +gpu=32 + # Cleans out modules loaded in interactive and inherited by default module purge @@ -29,7 +35,7 @@ export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset # Running code -srun bash run_multihead.sh +srun bash run_multihead_mpb.sh ${bs} ${lr} #mace_run_train \ # --name="Test_Multihead_MultiGPU_SpiceMP_MACE" \ diff --git a/n1gpu4_train_mponly.slurm b/n1gpu4_train_mponly.slurm new file mode 100644 index 00000000..0bafc2bf --- /dev/null +++ b/n1gpu4_train_mponly.slurm @@ -0,0 +1,72 @@ +#!/bin/bash +#SBATCH --job-name=train-mace # job name +#SBATCH --account=gax@h100 # account +#SBATCH -C h100 # target H100 nodes +# Here, reservation of 3x24=72 CPUs (for 3 tasks) and 3 GPUs (1 GPU per task) on one node only: +#SBATCH --nodes=2 # number of node +#SBATCH --ntasks-per-node=4 # number of MPI tasks per node (here = number of GPUs per node) +#SBATCH --gres=gpu:4 # number of GPUs per node (max 4 for H100 nodes) +# Knowing that here we only reserve one GPU per task (i.e. 1/4 of GPUs), +# the ideal is to reserve 1/4 of CPUs for each task: +#SBATCH --cpus-per-task=8 # number of CPUs per task (here 1/4 of the node) +# /!\ Caution, "multithread" in Slurm vocabulary refers to hyperthreading. 
+#SBATCH --hint=nomultithread # hyperthreading deactivated +#SBATCH --time=20:00:00 # maximum execution time requested (HH:MM:SS) +#SBATCH --output=n1gpu4-mace-train-%A_%a.out # name of output file +#SBATCH --error=n1gpu4-mace-train-%A_%a.out # name of error file (here, in common with the output file) +#SBATCH --array=0-3%1 # Array index range + +# Access arguments +bs=64 +lr=0.005 + +# Cleans out modules loaded in interactive and inherited by default +module purge + +# Loading modules +module load pytorch-gpu/py3/2.3.1 + +# Echo of launched commands +set -x + +# set path +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset + +# Running code +srun bash run_multihead_mpb_mponly.sh $bs $lr + +#mace_run_train \ +# --name="Test_Multihead_MultiGPU_SpiceMP_MACE" \ +# --model="MACE" \ +# --num_interactions=2 \ +# --num_channels=224 \ +# --max_L=0 \ +# --correlation=3 \ +# --r_max=5.0 \ +# --forces_weight=1000 \ +# --energy_weight=40 \ +# --weight_decay=5e-10 \ +# --clip_grad=1.0 \ +# --batch_size=32 \ +# --valid_batch_size=128 \ +# --max_num_epochs=210 \ +# --patience=50 \ +# --eval_interval=1 \ +# --ema \ +# --num_workers=8 \ +# --error_table='PerAtomMAE' \ +# --default_dtype="float64"\ +# --device=cuda \ +# --seed=123 \ +# --save_cpu \ +# --restart_latest \ +# --loss="weighted" \ +# --scheduler_patience=20 \ +# --lr=0.01 \ +# --swa \ +# --swa_lr=0.00025 \ +# --swa_forces_weight=100 \ +# --start_swa=190 \ +# --config="multihead_config/jz_spice_mp_config.yaml" \ +# --distributed \ diff --git a/run_multihead.sh b/run_multihead.sh index c965e3a7..3bde22e7 100644 --- a/run_multihead.sh +++ b/run_multihead.sh @@ -6,7 +6,7 @@ mace_run_train \ --model="MACE" \ --num_interactions=2 \ --num_channels=224 \ - --max_L=0 \ + --max_L=2 \ --correlation=3 \ --r_max=5.0 \ --forces_weight=1000 \ @@ -22,7 +22,6 @@ mace_run_train \ --num_workers=8 \ --error_table='PerAtomMAE' \ --default_dtype="float64"\ - --device=cuda \ --seed=0 \ --save_cpu \ --restart_latest \ @@ -34,6 +33,8 @@ mace_run_train \ --swa_forces_weight=100 \ --start_swa=190 \ --config="multihead_config/jz_spice_mp_config.yaml" \ + --device=cuda \ --distributed \ + # seed 0 for test, seed 123 for first run diff --git a/run_multihead_mpb.sh b/run_multihead_mpb.sh new file mode 100755 index 00000000..7ffe1f38 --- /dev/null +++ b/run_multihead_mpb.sh @@ -0,0 +1,49 @@ +#!/bin/bash +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset +module load pytorch-gpu/py3/2.3.1 +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +mace_run_train \ + --name="MACE_medium_agnesi_b$1_lr$2" \ + --loss='universal' \ + --energy_weight=1 \ + --forces_weight=10 \ + --compute_stress=True \ + --stress_weight=10 \ + --eval_interval=1 \ + --error_table='PerAtomMAE' \ + --model="MACE" \ + --interaction_first="RealAgnosticInteractionBlock" \ + --interaction="RealAgnosticResidualInteractionBlock" \ + --num_interactions=2 \ + --correlation=3 \ + --max_ell=3 \ + --r_max=6.0 \ + --max_L=1 \ + --num_channels=128 \ + --num_radial_basis=10 \ + --MLP_irreps="16x0e" \ + --scaling='rms_forces_scaling' \ + --lr=$2 \ + --weight_decay=1e-8 \ + --ema \ + --ema_decay=0.995 \ + --scheduler_patience=5 \ + --batch_size=$1 \ + --valid_batch_size=32 \ + --pair_repulsion \ + --distance_transform="Agnesi" \ + --max_num_epochs=100 \ + --patience=40 \ + --amsgrad \ + --seed=1 \ + --clip_grad=100 \ + --keep_checkpoints \ + --restart_latest \ + --save_cpu \ + 
--config="multihead_config/jz_spice_mp_config_r6.0.yaml" \ + --device=cuda \ + --num_workers=8 \ + --distributed \ + + +# --name="MACE_medium_agnesi_b32_origin_mponly" \ diff --git a/run_multihead_mpb_3args.sh b/run_multihead_mpb_3args.sh new file mode 100755 index 00000000..c5782ce0 --- /dev/null +++ b/run_multihead_mpb_3args.sh @@ -0,0 +1,50 @@ +#!/bin/bash +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset +module load pytorch-gpu/py3/2.3.1 +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +REAL_BATCH_SIZE=$(($1 * $3)) +mace_run_train \ + --name="MACE_medium_agnesi_b${REAL_BATCH_SIZE}_lr$2" \ + --loss='universal' \ + --energy_weight=1 \ + --forces_weight=10 \ + --compute_stress=True \ + --stress_weight=10 \ + --eval_interval=1 \ + --error_table='PerAtomMAE' \ + --model="MACE" \ + --interaction_first="RealAgnosticInteractionBlock" \ + --interaction="RealAgnosticResidualInteractionBlock" \ + --num_interactions=2 \ + --correlation=3 \ + --max_ell=3 \ + --r_max=6.0 \ + --max_L=1 \ + --num_channels=128 \ + --num_radial_basis=10 \ + --MLP_irreps="16x0e" \ + --scaling='rms_forces_scaling' \ + --lr=$2 \ + --weight_decay=1e-8 \ + --ema \ + --ema_decay=0.995 \ + --scheduler_patience=5 \ + --batch_size=$1 \ + --valid_batch_size=32 \ + --pair_repulsion \ + --distance_transform="Agnesi" \ + --max_num_epochs=100 \ + --patience=40 \ + --amsgrad \ + --seed=1 \ + --clip_grad=100 \ + --keep_checkpoints \ + --restart_latest \ + --save_cpu \ + --config="multihead_config/jz_spice_mp_config_r6.0.yaml" \ + --device=cuda \ + --num_workers=8 \ + --distributed \ + + +# --name="MACE_medium_agnesi_b32_origin_mponly" \ diff --git a/run_multihead_mpb_3args_mponly.sh b/run_multihead_mpb_3args_mponly.sh new file mode 100755 index 00000000..6c810874 --- /dev/null +++ b/run_multihead_mpb_3args_mponly.sh @@ -0,0 +1,50 @@ +#!/bin/bash +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset +module load pytorch-gpu/py3/2.3.1 +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +REAL_BATCH_SIZE=$(($1 * $3)) +mace_run_train \ + --name="MACE_medium_agnesi_b${REAL_BATCH_SIZE}_lr$2_mponly" \ + --loss='universal' \ + --energy_weight=1 \ + --forces_weight=10 \ + --compute_stress=True \ + --stress_weight=10 \ + --eval_interval=1 \ + --error_table='PerAtomMAE' \ + --model="MACE" \ + --interaction_first="RealAgnosticInteractionBlock" \ + --interaction="RealAgnosticResidualInteractionBlock" \ + --num_interactions=2 \ + --correlation=3 \ + --max_ell=3 \ + --r_max=6.0 \ + --max_L=1 \ + --num_channels=128 \ + --num_radial_basis=10 \ + --MLP_irreps="16x0e" \ + --scaling='rms_forces_scaling' \ + --lr=$2 \ + --weight_decay=1e-8 \ + --ema \ + --ema_decay=0.995 \ + --scheduler_patience=5 \ + --batch_size=$1 \ + --valid_batch_size=32 \ + --pair_repulsion \ + --distance_transform="Agnesi" \ + --max_num_epochs=100 \ + --patience=40 \ + --amsgrad \ + --seed=1 \ + --clip_grad=100 \ + --keep_checkpoints \ + --restart_latest \ + --save_cpu \ + --config="multihead_config/jz_mp_config_r6.0.yaml" \ + --device=cuda \ + --num_workers=8 \ + --distributed \ + + +# --name="MACE_medium_agnesi_b32_origin_mponly" \ diff --git a/run_multihead_mpb_mponly.sh b/run_multihead_mpb_mponly.sh new file mode 100755 index 00000000..7fb665dd --- /dev/null +++ b/run_multihead_mpb_mponly.sh @@ -0,0 +1,49 @@ +#!/bin/bash +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset +module load pytorch-gpu/py3/2.3.1 +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +mace_run_train \ + 
--name="MACE_medium_agnesi_b$1_lr$2_mponly" \ + --loss='universal' \ + --energy_weight=1 \ + --forces_weight=10 \ + --compute_stress=True \ + --stress_weight=10 \ + --eval_interval=1 \ + --error_table='PerAtomMAE' \ + --model="MACE" \ + --interaction_first="RealAgnosticInteractionBlock" \ + --interaction="RealAgnosticResidualInteractionBlock" \ + --num_interactions=2 \ + --correlation=3 \ + --max_ell=3 \ + --r_max=6.0 \ + --max_L=1 \ + --num_channels=128 \ + --num_radial_basis=10 \ + --MLP_irreps="16x0e" \ + --scaling='rms_forces_scaling' \ + --lr=$2 \ + --weight_decay=1e-8 \ + --ema \ + --ema_decay=0.995 \ + --scheduler_patience=5 \ + --batch_size=$1 \ + --valid_batch_size=32 \ + --pair_repulsion \ + --distance_transform="Agnesi" \ + --max_num_epochs=100 \ + --patience=40 \ + --amsgrad \ + --seed=1 \ + --clip_grad=100 \ + --keep_checkpoints \ + --restart_latest \ + --save_cpu \ + --config="multihead_config/jz_mp_config_r6.0.yaml" \ + --device=cuda \ + --num_workers=8 \ + --distributed \ + + +# --name="MACE_medium_agnesi_b32_origin_mponly" \ diff --git a/run_multihead_slack.sh b/run_multihead_slack.sh new file mode 100644 index 00000000..f7867013 --- /dev/null +++ b/run_multihead_slack.sh @@ -0,0 +1,46 @@ +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset +module load pytorch-gpu/py3/2.3.1 +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +mace_run_train \ + --name="Test_Multihead_Agnesi" \ + --r_max=6.0 \ + --forces_weight=10 \ + --energy_weight=1 \ + --stress_weight=10 \ + --clip_grad=100 \ + --batch_size=32 \ + --valid_batch_size=128 \ + --num_workers=8 \ + --default_dtype="float64" \ + --seed=124 \ + --loss="universal" \ + --scheduler_patience=5 \ + --config="multihead_config/jz_spice_mp_config_r6.0.yaml" \ + --error_table='PerAtomMAE' \ + --model="MACE" \ + --interaction_first="RealAgnosticInteractionBlock" \ + --interaction="RealAgnosticResidualInteractionBlock" \ + --num_interactions=2 \ + --correlation=3 \ + --max_ell=3 \ + --max_L=1 \ + --num_channels=128 \ + --num_radial_basis=10 \ + --MLP_irreps="16x0e" \ + --scaling='rms_forces_scaling' \ + --lr=0.005 \ + --weight_decay=1e-8 \ + --ema \ + --ema_decay=0.995 \ + --pair_repulsion \ + --distance_transform="Agnesi" \ + --max_num_epochs=250 \ + --patience=40 \ + --amsgrad \ + --device=cuda \ + --clip_grad=100 \ + --keep_checkpoints \ + --restart_latest \ + --distributed \ + --save_cpu + diff --git a/train.slurm b/train.slurm new file mode 100644 index 00000000..61513f66 --- /dev/null +++ b/train.slurm @@ -0,0 +1,73 @@ +#!/bin/bash +#SBATCH --job-name=train-mace # job name +#SBATCH --account=gax@h100 # account +#SBATCH -C h100 # target H100 nodes +# Here, reservation of 3x24=72 CPUs (for 3 tasks) and 3 GPUs (1 GPU per task) on one node only: +#SBATCH --nodes=8 # number of node +#SBATCH --ntasks-per-node=4 # number of MPI tasks per node (here = number of GPUs per node) +#SBATCH --gres=gpu:4 # number of GPUs per node (max 4 for H100 nodes) +# Knowing that here we only reserve one GPU per task (i.e. 1/4 of GPUs), +# the ideal is to reserve 1/4 of CPUs for each task: +#SBATCH --cpus-per-task=8 # number of CPUs per task (here 1/4 of the node) +# /!\ Caution, "multithread" in Slurm vocabulary refers to hyperthreading. 
+#SBATCH --hint=nomultithread # hyperthreading deactivated +#SBATCH --time=20:00:00 # maximum execution time requested (HH:MM:SS) +#SBATCH --output=n1gpu4-mace-train-%A_%a.out # name of output file +#SBATCH --error=n1gpu4-mace-train-%A_%a.out # name of error file (here, in common with the output file) +#SBATCH --array=0-3%1 # Array index range + +# Access arguments +bs=32 +lr=0.005 +gpu=32 + +# Cleans out modules loaded in interactive and inherited by default +module purge + +# Loading modules +module load pytorch-gpu/py3/2.3.1 + +# Echo of launched commands +set -x + +# set path +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset + +# Running code +srun bash run_multihead_mpb_3args.sh ${bs} ${lr} ${gpu} + +#mace_run_train \ +# --name="Test_Multihead_MultiGPU_SpiceMP_MACE" \ +# --model="MACE" \ +# --num_interactions=2 \ +# --num_channels=224 \ +# --max_L=0 \ +# --correlation=3 \ +# --r_max=5.0 \ +# --forces_weight=1000 \ +# --energy_weight=40 \ +# --weight_decay=5e-10 \ +# --clip_grad=1.0 \ +# --batch_size=32 \ +# --valid_batch_size=128 \ +# --max_num_epochs=210 \ +# --patience=50 \ +# --eval_interval=1 \ +# --ema \ +# --num_workers=8 \ +# --error_table='PerAtomMAE' \ +# --default_dtype="float64"\ +# --device=cuda \ +# --seed=123 \ +# --save_cpu \ +# --restart_latest \ +# --loss="weighted" \ +# --scheduler_patience=20 \ +# --lr=0.01 \ +# --swa \ +# --swa_lr=0.00025 \ +# --swa_forces_weight=100 \ +# --start_swa=190 \ +# --config="multihead_config/jz_spice_mp_config.yaml" \ +# --distributed \ diff --git a/train_mponly.slurm b/train_mponly.slurm new file mode 100644 index 00000000..19770f4b --- /dev/null +++ b/train_mponly.slurm @@ -0,0 +1,73 @@ +#!/bin/bash +#SBATCH --job-name=train-mace # job name +#SBATCH --account=gax@h100 # account +#SBATCH -C h100 # target H100 nodes +# Here, reservation of 3x24=72 CPUs (for 3 tasks) and 3 GPUs (1 GPU per task) on one node only: +#SBATCH --nodes=8 # number of node +#SBATCH --ntasks-per-node=4 # number of MPI tasks per node (here = number of GPUs per node) +#SBATCH --gres=gpu:4 # number of GPUs per node (max 4 for H100 nodes) +# Knowing that here we only reserve one GPU per task (i.e. 1/4 of GPUs), +# the ideal is to reserve 1/4 of CPUs for each task: +#SBATCH --cpus-per-task=8 # number of CPUs per task (here 1/4 of the node) +# /!\ Caution, "multithread" in Slurm vocabulary refers to hyperthreading. 
+#SBATCH --hint=nomultithread # hyperthreading deactivated +#SBATCH --time=20:00:00 # maximum execution time requested (HH:MM:SS) +#SBATCH --output=n1gpu4-mace-train-%A_%a.out # name of output file +#SBATCH --error=n1gpu4-mace-train-%A_%a.out # name of error file (here, in common with the output file) +#SBATCH --array=0-3%1 # Array index range + +# Access arguments +bs=32 +lr=0.005 +gpu=32 + +# Cleans out modules loaded in interactive and inherited by default +module purge + +# Loading modules +module load pytorch-gpu/py3/2.3.1 + +# Echo of launched commands +set -x + +# set path +export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin" +DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset + +# Running code +srun bash run_multihead_mpb_3args_mponly.sh ${bs} ${lr} ${gpu} + +#mace_run_train \ +# --name="Test_Multihead_MultiGPU_SpiceMP_MACE" \ +# --model="MACE" \ +# --num_interactions=2 \ +# --num_channels=224 \ +# --max_L=0 \ +# --correlation=3 \ +# --r_max=5.0 \ +# --forces_weight=1000 \ +# --energy_weight=40 \ +# --weight_decay=5e-10 \ +# --clip_grad=1.0 \ +# --batch_size=32 \ +# --valid_batch_size=128 \ +# --max_num_epochs=210 \ +# --patience=50 \ +# --eval_interval=1 \ +# --ema \ +# --num_workers=8 \ +# --error_table='PerAtomMAE' \ +# --default_dtype="float64"\ +# --device=cuda \ +# --seed=123 \ +# --save_cpu \ +# --restart_latest \ +# --loss="weighted" \ +# --scheduler_patience=20 \ +# --lr=0.01 \ +# --swa \ +# --swa_lr=0.00025 \ +# --swa_forces_weight=100 \ +# --start_swa=190 \ +# --config="multihead_config/jz_spice_mp_config.yaml" \ +# --distributed \
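Note on the head_stress_mask wiring above: run_train.py builds a {0,1} vector with one entry per head (currently 1 only for heads whose name contains "mp"), and UniversalLoss indexes it with each graph's head id to zero the stress term for heads whose reference stresses should not be trained on. A minimal standalone sketch of that masking, with illustrative names and plain tensors in place of the Batch/TensorDict objects:

import torch

heads = ["spice_wB97M", "mp_pbe"]  # head order as listed in the config
head_stress_mask = torch.tensor([float("mp" in h) for h in heads])
huber = torch.nn.HuberLoss(reduction="mean", delta=0.01)

def masked_stress_loss(ref_stress, pred_stress, head_index):
    # head_index: [n_graphs] integer head id of each graph in the batch.
    # Graphs from heads with mask 0 contribute zero error to the Huber term
    # (they still count in the mean denominator, as in the patched loss).
    mask = head_stress_mask[head_index].view(-1, 1, 1)  # broadcast over the 3x3 stress
    return huber(ref_stress * mask, pred_stress * mask)

ref_stress = torch.randn(4, 3, 3)
pred_stress = torch.randn(4, 3, 3)
head_index = torch.tensor([0, 1, 0, 1])  # alternating spice / mp graphs
print(masked_stress_loss(ref_stress, pred_stress, head_index))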