Skip to content

Commit

Permalink
more digits in training validation err log
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Jun 7, 2024
1 parent c8760d9 commit 9cef22a
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def valid_err_log(
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_E_per_atom={error_e:.2f} meV, RMSE_F={error_f:.1f} meV / A"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand All @@ -66,7 +66,7 @@ def valid_err_log(
error_f = eval_metrics["rmse_f"] * 1e3
error_stress = eval_metrics["rmse_stress"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_E_per_atom={error_e:.2f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.2f} meV / A^3"
)
elif (
log_errors == "PerAtomRMSEstressvirials"
Expand All @@ -76,37 +76,37 @@ def valid_err_log(
error_f = eval_metrics["rmse_f"] * 1e3
error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_E_per_atom={error_e:.2f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV"
)
elif log_errors == "TotalRMSE":
error_e = eval_metrics["rmse_e"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A"
)
elif log_errors == "PerAtomMAE":
error_e = eval_metrics["mae_e_per_atom"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, MAE_E_per_atom={error_e:.2f} meV, MAE_F={error_f:.1f} meV / A"
)
elif log_errors == "TotalMAE":
error_e = eval_metrics["mae_e"] * 1e3
error_f = eval_metrics["mae_f"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A"
)
elif log_errors == "DipoleRMSE":
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_MU_per_atom={error_mu:.2f} mDebye"
)
elif log_errors == "EnergyDipoleRMSE":
error_e = eval_metrics["rmse_e_per_atom"] * 1e3
error_f = eval_metrics["rmse_f"] * 1e3
error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
logging.info(
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye"
f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.8f}, RMSE_E_per_atom={error_e:.2f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye"
)


Expand Down

0 comments on commit 9cef22a

Please sign in to comment.