Skip to content

Commit

Permalink
slurm 3 args script
Browse files Browse the repository at this point in the history
  • Loading branch information
birdyLinch committed Aug 11, 2024
1 parent e7044c2 commit b2d282d
Show file tree
Hide file tree
Showing 21 changed files with 753 additions and 23 deletions.
8 changes: 6 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def main() -> None:
head_args.valid_set = data.dataset_from_sharded_hdf5(
head_args.valid_file, r_max=head_args.r_max, z_table=z_table, head=head, heads=list(args.heads.keys()), rank=rank
)

# subset train ratio
if "train_ratio" in head_args.keys():
ratio = head_args.train_ratio
Expand Down Expand Up @@ -428,11 +428,13 @@ def main() -> None:
huber_delta=args.huber_delta,
)
elif args.loss == "universal":
head_stress_mask = torch.Tensor([float('mp' in k) for k in args.heads.keys()]).to(device=device) # TODO: make it general
loss_fn = modules.UniversalLoss(
energy_weight=args.energy_weight,
forces_weight=args.forces_weight,
stress_weight=args.stress_weight,
huber_delta=args.huber_delta,
head_stress_mask=head_stress_mask
)
elif args.loss == "dipole":
assert (
Expand All @@ -458,7 +460,9 @@ def main() -> None:
if args.loss in ("stress", "virials", "huber", "universal"):
compute_virials = True
args.compute_stress = True
args.error_table = "PerAtomRMSEstressvirials"
# args.error_table = "PerAtomRMSEstressvirials"
logging.info(f"Over-wrighting the error table due to the loss setting -> {args.loss} loss")
args.error_table = "PerAtomRMSE+EMAEstressvirials"

output_args = {
"energy": compute_energy,
Expand Down
37 changes: 36 additions & 1 deletion mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,41 @@ def forward(
return self.linear(x) # [n_nodes, 1]


class GroupavgReadoutBlock(torch.nn.Module):
    """Readout that averages an MLP over a grid of rotated copies of its input.

    The input features are rotated by every element of a precomputed SO(3)
    grid (composed with one fresh random rotation per call), each rotated
    copy is passed through the same MLP, and the MLP outputs are averaged
    over the grid.

    Args:
        irreps_in: Irreps of the input features; ``irreps_in.dim`` is the MLP
            input width and ``irreps_in.D_from_matrix`` builds the rotation
            representations.
        gate: Stored as ``self.non_linearity`` but not used by the MLP, which
            always applies ``torch.nn.SiLU``.
        irrep_out: Irreps of the output; ``irrep_out.dim`` is the MLP output
            width.
        layered: First index of the grid file name
            ``SO3_grid_{layered}_{resolution}.pt``.
            NOTE(review): the grid files referenced elsewhere in this commit
            use ``layered`` in {1, 2} -- confirm which files exist.
        resolution: Second index of the grid file name.
            NOTE(review): the grid files referenced elsewhere in this commit
            use ``resolution`` in {0, 1, 2} -- confirm which files exist.
        grid_dir: Directory containing the grid files. ``None`` (default)
            falls back to ``DEFAULT_GRID_DIR``, the location the original
            code hard-coded.
    """

    # Location that was previously hard-coded in __init__; kept as the default
    # so existing callers load exactly the same file as before.
    DEFAULT_GRID_DIR = "/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid"

    def __init__(
        self,
        irreps_in: o3.Irreps,
        gate: Optional[Callable],
        irrep_out: o3.Irreps = o3.Irreps("0e"),
        layered: int = 2,
        resolution: int = 2,
        grid_dir: Optional[str] = None,
    ):
        super().__init__()
        self.irreps_in = irreps_in
        self.non_linearity = gate
        input_size = irreps_in.dim
        output_size = irrep_out.dim
        hidden_size = 128
        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.BatchNorm1d(hidden_size),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, output_size),
        )
        self.layered = layered
        self.resolution = resolution

        if grid_dir is None:
            grid_dir = self.DEFAULT_GRID_DIR
        grid_path = f"{grid_dir}/SO3_grid_{layered}_{resolution}.pt"
        # The file is passed through o3.quaternion_to_matrix, so it is
        # expected to hold quaternions; the result is kept as a buffer so it
        # follows the module across devices.
        grid_quaternions = torch.load(grid_path).to(torch.get_default_dtype())
        self.register_buffer("SO3_grid", o3.quaternion_to_matrix(grid_quaternions))

    def forward(self, x: torch.Tensor, heads: Optional[torch.Tensor] = None):
        """Return the grid-averaged MLP output for ``x``.

        Args:
            x: Input features whose last dimension is ``irreps_in.dim``
                (the einsum below requires a 2-D ``[N, D]`` tensor).
            heads: Accepted for interface compatibility with the other
                readout blocks; not used here.

        Returns:
            Tensor of shape ``[N, irrep_out.dim]``: the mean of the MLP
            outputs over all grid rotations.
        """
        # One random rotation per call, applied to the whole fixed grid.
        rand_D = o3.rand_matrix(device=x.device)
        # Fixed: the buffer is registered as "SO3_grid"; the previous
        # reference to ``self.SO3_grid_2_2`` raised AttributeError.
        gs = self.SO3_grid @ rand_D  # [R, 3, 3], R = number of grid rotations
        Ds = self.irreps_in.D_from_matrix(gs)  # [R, D, D]

        # Apply every rotation representation to every input row.
        xs = torch.einsum("nd,rjd->nrj", x, Ds)  # [N, D], [R, D, D] -> [N, R, D]
        # reshape (rather than view) does not depend on the memory layout of
        # the einsum result.
        outs = self.MLP(xs.reshape(-1, xs.size(-1)))  # [N * R, out]
        out = torch.mean(outs.reshape(*xs.shape[:-1], -1), dim=1, keepdim=False)
        return out

@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
Expand All @@ -80,7 +115,7 @@ def forward(
) -> torch.Tensor: # [n_nodes, irreps] # [..., ]
x = self.non_linearity(self.linear_1(x))
if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None:
x = mask_head(x, heads, self.num_heads)
x = mask_head(x, heads, self.num_heads) # decorrelate two mlps
return self.linear_2(x) # [n_nodes, len(heads)]


Expand Down
28 changes: 20 additions & 8 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ def __repr__(self):

class UniversalLoss(torch.nn.Module):
def __init__(
self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01
self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01,
head_stress_mask=None
) -> None:
super().__init__()
self.huber_delta = huber_delta
Expand All @@ -270,16 +271,27 @@ def __init__(
"stress_weight",
torch.tensor(stress_weight, dtype=torch.get_default_dtype()),
)
self.head_stress_mask=head_stress_mask

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
    """Return the weighted sum of the energy, forces and stress Huber terms.

    The energy term compares per-atom energies, the forces term is delegated
    to ``conditional_huber_forces``, and the stress term compares the stress
    tensors. When ``self.head_stress_mask`` is set, both reference and
    predicted stress are multiplied by the mask entry selected by
    ``ref.head`` before the Huber loss is taken.
    """
    # Atoms per graph, from consecutive offsets in ``ref.ptr``.
    atoms_per_graph = ref.ptr[1:] - ref.ptr[:-1]

    energy_term = self.energy_weight * self.huber_loss(
        ref["energy"] / atoms_per_graph, pred["energy"] / atoms_per_graph
    )
    forces_term = self.forces_weight * conditional_huber_forces(
        ref, pred, huber_delta=self.huber_delta
    )

    stress_ref = ref["stress"]
    stress_pred = pred["stress"]
    if self.head_stress_mask is not None:
        # One mask value per graph, reshaped so it broadcasts over the two
        # trailing stress dimensions.
        stress_mask = self.head_stress_mask[ref.head].view(-1, 1, 1)
        stress_ref = stress_ref * stress_mask
        stress_pred = stress_pred * stress_mask
    stress_term = self.stress_weight * self.huber_loss(stress_ref, stress_pred)

    return energy_term + forces_term + stress_term

def __repr__(self):
return (
Expand Down
1 change: 1 addition & 0 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def forward(
# Interactions
node_es_list = [pair_node_energy]
node_feats_list = []
# import ipdb; ipdb.set_trace()
for interaction, product, readout in zip(
self.interactions, self.products, self.readouts
):
Expand Down
126 changes: 126 additions & 0 deletions mace/modules/test_grpavg_readout.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [],
"source": [
"from abc import abstractmethod\n",
"from typing import Callable, List, Optional, Tuple, Union\n",
"\n",
"import numpy as np\n",
"from torch.nn.functional import silu\n",
"from e3nn import nn, o3\n",
"from e3nn.util.jit import compile_mode\n",
"import torch\n",
"# Set the default floating-point type to float64\n",
"torch.set_default_dtype(torch.float64)\n",
"\n",
"class GroupavgReadoutBlock(torch.nn.Module):\n",
"\n",
" def __init__(self, irreps_in: o3.Irreps,\n",
" gate: Optional[Callable],\n",
" irrep_out: o3.Irreps=o3.Irreps(\"0e\"),\n",
" ):\n",
" super().__init__()\n",
" self.irreps_in = irreps_in\n",
" self.non_linearity = gate\n",
" input_size = irreps_in.dim\n",
" output_size = irrep_out.dim\n",
" hidden_size = 128\n",
" self.MLP = torch.nn.Sequential(\n",
" torch.nn.Linear(input_size, hidden_size),\n",
" torch.nn.BatchNorm1d(hidden_size),\n",
" torch.nn.SiLU(),\n",
" torch.nn.Linear(hidden_size, output_size)\n",
" )\n",
" self.register_buffer(\"SO3_grid_1_0\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_0.pt\").to(torch.get_default_dtype())))\n",
" self.register_buffer(\"SO3_grid_1_1\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_1.pt\").to(torch.get_default_dtype())))\n",
" self.register_buffer(\"SO3_grid_1_2\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_1_2.pt\").to(torch.get_default_dtype())))\n",
" self.register_buffer(\"SO3_grid_2_0\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_0.pt\").to(torch.get_default_dtype())))\n",
" self.register_buffer(\"SO3_grid_2_1\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_1.pt\").to(torch.get_default_dtype())))\n",
" self.register_buffer(\"SO3_grid_2_2\", \n",
" o3.quaternion_to_matrix(torch.load(\"/lustre/fsn1/projects/rech/gax/unh55hx/misc/SO3_grid/SO3_grid_2_2.pt\").to(torch.get_default_dtype())))\n",
"\n",
"\n",
" def forward(self, x: torch.Tensor, heads: Optional[torch.Tensor] = None):\n",
" rand_D = o3.rand_matrix(device=x.device)\n",
" gs = self.SO3_grid_1_2 @ rand_D # [72, 3, 3]\n",
" Ds = self.irreps_in.D_from_matrix(gs) # [72, D, D]\n",
"\n",
" xs = torch.einsum(\"nd,rjd->nrj\", x, Ds) # [n_graphs, D], [72, D, D] -> [n_graphs, 72, D]\n",
" print(xs.shape)\n",
" outs = self.MLP(xs.view(-1, xs.size(-1))) # [n_graph, 72, 1]\n",
" out = torch.mean(outs.view(*xs.shape[:-1], -1), dim=1, keepdim=False)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [],
"source": [
"irreps_in = o3.Irreps(\"3x0e+1x1o+1x2e\")\n",
"n_graph = 32\n",
"readout = GroupavgReadoutBlock(irreps_in=irreps_in, gate=torch.nn.SiLU)"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([32, 4608, 11])\n",
"torch.Size([32, 4608, 11])\n",
"tensor(0.0003, grad_fn=<MeanBackward0>)\n"
]
}
],
"source": [
"x = irreps_in.randn(n_graph, -1)\n",
"\n",
"out = readout(x)\n",
"\n",
"rot_x = x @ irreps_in.D_from_matrix(o3.rand_matrix())\n",
"\n",
"rot_out = readout(rot_x)\n",
"# print(x - rot_x)\n",
"print((rot_out - out).abs().mean())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 2 additions & 2 deletions mace/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ def compute_mean_rms_energy_forces(
forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], }
head = torch.cat(head_list, dim=0) # [total_n_graphs]
head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs]

mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0))
rms = to_numpy(
torch.sqrt(
scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
Expand Down
15 changes: 15 additions & 0 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def valid_err_log(
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomRMSE+EMAEstressvirials" and eval_metrics["rmse_stress"] is not None
):
error_e_rmse = eval_metrics["rmse_e_per_atom"] * 1e3
error_f_rmse = eval_metrics["rmse_f"] * 1e3
error_stress_rmse = eval_metrics["rmse_stress"] * 1e3
error_e_mae = eval_metrics["mae_e_per_atom"] * 1e3
error_f_mae = eval_metrics["mae_f"] * 1e3
error_stress_mae = eval_metrics["mae_stress"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, \t RMSE_E_per_atom={error_e_rmse:.1f} meV, RMSE_F={error_f_rmse:.1f} meV / A, RMSE_stress={error_stress_rmse:.1f} meV / A^3"
)
logging.info(
f" \t MAE_E_per_atom={error_e_mae:.1f} meV, MAE_F={error_f_mae:.1f} meV / A, MAE_stress={error_stress_mae:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_virials_per_atom"] is not None
Expand Down
15 changes: 15 additions & 0 deletions multihead_config/jz_mp_config_r6.0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
avg_num_neighbor_head: mp_pbe
device: cuda
multi_processed_test: True
heads:
mp_pbe:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
avg_num_neighbors: 61.9649349317854
mean: 0.1634233391135065
std: 0.7735790334431056
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
25 changes: 25 additions & 0 deletions multihead_config/jz_oc_mp_config_r6.0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
avg_num_neighbor_head: mp_pbe
device: cuda
multi_processed_test: True
heads:
spice_wB97M:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/spice
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/spice
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
avg_num_neighbors: 22.86736849018836
mean: -4.406405198254238
std: 1.0737544472166278

mp_pbe:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
avg_num_neighbors: 61.9649349317854
mean: 0.1634233391135065
std: 0.7735790334431056
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
7 changes: 5 additions & 2 deletions multihead_config/jz_spice_mp_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ heads:
avg_num_neighbors: 35.985167534166
mean: -4.48071865
std: 0.77357903
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json

# test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
# statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
# mean and std do not depend on r
# no online, compute statistics script with the same yaml.
25 changes: 25 additions & 0 deletions multihead_config/jz_spice_mp_config_r6.0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
avg_num_neighbor_head: mp_pbe
device: cuda
multi_processed_test: True
heads:
spice_wB97M:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/spice
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/spice
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
avg_num_neighbors: 22.86736849018836
mean: -4.406405198254238
std: 1.0737544472166278

mp_pbe:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
avg_num_neighbors: 61.9649349317854
mean: 0.1634233391135065
std: 0.7735790334431056
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
Loading

0 comments on commit b2d282d

Please sign in to comment.