Skip to content

Commit

Permalink
get rid of all stress/n_atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Jun 4, 2024
1 parent e4ac498 commit 0842e7c
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def valid_err_log(
)
elif (
log_errors == "PerAtomRMSEstressvirials"
and eval_metrics["rmse_stress_per_atom"] is not None
and eval_metrics["rmse_stress"] is not None
):
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand Down Expand Up @@ -405,7 +405,6 @@ def __init__(self, loss_fn: torch.nn.Module):
"stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state("delta_stress", default=[], dist_reduce_fx="cat")
self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat")
self.add_state(
"virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
Expand Down Expand Up @@ -434,10 +433,6 @@ def update(self, batch, output): # pylint: disable=arguments-differ
if output.get("stress") is not None and batch.stress is not None:
self.stress_computed += 1.0
self.delta_stress.append(batch.stress - output["stress"])
self.delta_stress_per_atom.append(
(batch.stress - output["stress"])
/ (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1)
)
if output.get("virials") is not None and batch.virials is not None:
self.virials_computed += 1.0
self.delta_virials.append(batch.virials - output["virials"])
Expand Down Expand Up @@ -480,10 +475,8 @@ def compute(self):
aux["q95_f"] = compute_q95(delta_fs)
if self.stress_computed:
delta_stress = self.convert(self.delta_stress)
delta_stress_per_atom = self.convert(self.delta_stress_per_atom)
aux["mae_stress"] = compute_mae(delta_stress)
aux["rmse_stress"] = compute_rmse(delta_stress)
aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom)
aux["q95_stress"] = compute_q95(delta_stress)
if self.virials_computed:
delta_virials = self.convert(self.delta_virials)
Expand Down

0 comments on commit 0842e7c

Please sign in to comment.