diff --git a/mace/tools/train.py b/mace/tools/train.py index 32d33588..a3f73ff9 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -209,6 +209,7 @@ def train( output_args=output_args, device=device, ) + if rank == 0: valid_err_log( valid_loss, eval_metrics, @@ -216,29 +217,42 @@ def train( log_errors, epoch, ) - - if log_wandb: - wandb_log_dict = { - "epoch": epoch, - "valid_loss": valid_loss, - "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], - "valid_rmse_f": eval_metrics["rmse_f"], - } - wandb.log(wandb_log_dict) - - if valid_loss >= lowest_loss: - patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" - ) - epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break - if save_all_checkpoints: + if log_wandb: + wandb_log_dict = { + "epoch": epoch, + "valid_loss": valid_loss, + "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], + "valid_rmse_f": eval_metrics["rmse_f"], + } + wandb.log(wandb_log_dict) + + if valid_loss >= lowest_loss: + patience_counter += 1 + if patience_counter >= patience and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" + ) + epoch = swa.start + elif patience_counter >= patience and epoch >= swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break + if save_all_checkpoints: + param_context = ( + ema.average_parameters() + if ema is not None + else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, + ) + else: + lowest_loss = valid_loss + patience_counter = 0 param_context = ( ema.average_parameters() if ema is not None else nullcontext() ) @@ -246,21 +260,9 @@ def train( checkpoint_handler.save( state=CheckpointState(model, optimizer, lr_scheduler), epochs=epoch, - keep_last=True, + keep_last=keep_last, ) - else: - lowest_loss = valid_loss - patience_counter = 0 - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=keep_last, - ) - keep_last = False or save_all_checkpoints + keep_last = False or save_all_checkpoints if distributed: torch.distributed.barrier() epoch += 1