diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..cf8dae6c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include mace/py.typed diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 5c198ec4..de34b1d4 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -155,6 +155,7 @@ def run(args: argparse.Namespace): # Data preparation collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -211,7 +212,7 @@ def run(args: argparse.Namespace): atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") + logging.info(f"Atomic Energies: {atomic_energies.tolist()}") _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] avg_num_neighbors, mean, std=pool_compute_stats(_inputs) logging.info(f"Average number of neighbors: {avg_num_neighbors}") diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 725c5d21..f98c7a04 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -55,6 +55,7 @@ def run(args: argparse.Namespace) -> None: """ This script runs the training/fine tuning for mace """ + args, input_log_messages = tools.check_args(args) tag = tools.get_tag(name=args.name, seed=args.seed) if args.distributed: try: @@ -74,6 +75,9 @@ def run(args: argparse.Namespace) -> None: # Setup tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========VERIFYING SETTINGS===========") + for message, loglevel in input_log_messages: + logging.log(level=loglevel, msg=message) if args.distributed: torch.cuda.set_device(local_rank) @@ -84,7 +88,7 @@ def run(args: argparse.Namespace) -> None: logging.info(f"MACE version: {mace.__version__}") except AttributeError: logging.info("Cannot find MACE version, please install MACE via pip") - logging.info(f"Configuration: {args}") + logging.debug(f"Configuration: {args}") tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) @@ -132,6 +136,8 @@ def run(args: argparse.Namespace) -> None: args.compute_avg_num_neighbors = False args.E0s = statistics["atomic_energies"] + logging.info("") + logging.info("===========LOADING INPUT DATA===========") # Data preparation if args.train_file.endswith(".xyz"): if args.valid_file is not None: @@ -140,6 +146,7 @@ def run(args: argparse.Namespace) -> None: ), "valid_file if given must be same format as train_file" config_type_weights = get_config_type_weights(args.config_type_weights) collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -154,11 +161,16 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) + if len(collections.train) < args.batch_size: + logging.error( + f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" + ) + if len(collections.valid) < args.valid_batch_size: + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" + ) + args.valid_batch_size = len(collections.valid) - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" - ) else: atomic_energies_dict = None @@ -181,12 +193,11 @@ def run(args: argparse.Namespace) -> None: assert isinstance(zs_list, list) z_table = tools.get_atomic_number_table_from_zs(zs_list) # yapf: enable - logging.info(z_table) + logging.info(f"Atomic Numbers used: {z_table.zs}") if atomic_energies_dict is None or len(atomic_energies_dict) == 0: if args.E0s.lower() == "foundation": assert args.foundation_model is not None - logging.info("Using atomic energies from foundation model") z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) @@ -196,6 +207,9 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } + logging.info( + f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" + ) else: if args.train_file.endswith(".xyz"): atomic_energies_dict = get_atomic_energies( @@ -226,7 +240,9 @@ def run(args: argparse.Namespace) -> None: atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") + logging.info( + f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" + ) if args.train_file.endswith(".xyz"): train_set = [ @@ -286,7 +302,8 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - + logging.info("") + logging.info("===========MODEL DETAILS===========") if args.loss == "weighted": loss_fn = modules.WeightedEnergyForcesLoss( energy_weight=args.energy_weight, forces_weight=args.forces_weight @@ -336,7 +353,6 @@ def run(args: argparse.Namespace) -> None: else: # Unweighted Energy and Forces loss by default loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) - logging.info(loss_fn) if args.compute_avg_num_neighbors: avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) @@ -350,14 +366,22 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = (num_neighbors / num_graphs).item() else: args.avg_num_neighbors = avg_num_neighbors - logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") + if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: + logging.warning( + f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}") # Selecting outputs compute_virials = False if args.loss in ("stress", "virials", "huber", "universal"): compute_virials = True args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" + if "MAE" in args.error_table: + args.error_table = "PerAtomMAEstressvirials" + else: + args.error_table = "PerAtomRMSEstressvirials" output_args = { "energy": compute_energy, @@ -366,7 +390,10 @@ def run(args: argparse.Namespace) -> None: "stress": args.compute_stress, "dipoles": compute_dipole, } - logging.info(f"Selected the following outputs: {output_args}") + + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) if args.scaling == "no_scaling": args.std = 1.0 @@ -377,11 +404,14 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Building model") + logging.info("Loading FOUNDATION model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_numbers"] = z_table.zs model_config_foundation["num_elements"] = len(z_table) args.max_L = model_config_foundation["hidden_irreps"].lmax + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(model_config_foundation["hidden_irreps"])} + )[0] model_config_foundation["atomic_inter_shift"] = ( model_foundation.scale_shift.shift.item() ) @@ -391,23 +421,35 @@ def run(args: argparse.Namespace) -> None: model_config_foundation["atomic_energies"] = atomic_energies args.model = "FoundationMACE" model_config = model_config_foundation # pylint + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) else: logging.info("Building model") - if args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - logging.info(f"Hidden irreps: {args.hidden_irreps}") - + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) model_config = dict( r_max=args.r_max, num_bessel=args.num_radial_basis, @@ -519,6 +561,20 @@ def run(args: argparse.Namespace) -> None: ) model.to(device) + logging.debug(model) + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") + logging.info(f"Batch size: {args.batch_size}") + if args.ema: + logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") + logging.info( + f"Number of gradient updates: {int(args.max_num_epochs*len(collections.train)/args.batch_size)}" + ) + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.info(loss_fn) + # Optimizer decay_interactions = {} no_decay_interactions = {} @@ -589,13 +645,9 @@ def run(args: argparse.Namespace) -> None: swas.append(True) if args.start_swa is None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) - else: - if args.start_swa > args.max_num_epochs: - logging.info( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" - ) - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info(f"Setting start Stage Two to {args.start_swa}") + logging.info( + f"Stage Two will start after {args.start_swa} epochs with loss function:" + ) if args.loss == "forces_only": raise ValueError("Can not select Stage Two with forces only loss.") if args.loss == "virials": @@ -616,17 +668,12 @@ def run(args: argparse.Namespace) -> None: forces_weight=args.swa_forces_weight, dipole_weight=args.swa_dipole_weight, ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" - ) else: loss_fn_energy = modules.WeightedEnergyForcesLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight, ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" - ) + logging.info(loss_fn_energy) swa = tools.SWAContainer( model=AveragedModel(model), scheduler=SWALR( @@ -670,10 +717,6 @@ def run(args: argparse.Namespace) -> None: for group in optimizer.param_groups: group["lr"] = args.lr - logging.info(model) - logging.info(f"Number of parameters: {tools.count_parameters(model)}") - logging.info(f"Optimizer: {optimizer}") - if args.wandb: logging.info("Using Weights and Biases for logging") import wandb @@ -723,7 +766,8 @@ def run(args: argparse.Namespace) -> None: train_sampler=train_sampler, rank=rank, ) - + logging.info("") + logging.info("===========RESULTS===========") logging.info("Computing metrics for training, validation, and test sets") all_data_loaders = { @@ -778,6 +822,13 @@ def run(args: argparse.Namespace) -> None: ) all_data_loaders[test_name] = test_loader + train_valid_data_loader = { + k: v for k, v in all_data_loaders.items() if k in ["train", "valid"] + } + test_data_loader = { + k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"] + } + for swa_eval in swas: epoch = checkpoint_handler.load_latest( state=tools.CheckpointState(model, optimizer, lr_scheduler), @@ -788,13 +839,27 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - logging.info(f"Loaded model from epoch {epoch}") + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") for param in model.parameters(): param.requires_grad = False - table = create_error_table( + + table_train = create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + table_test = create_error_table( table_type=args.error_table, - all_data_loaders=all_data_loaders, + all_data_loaders=test_data_loader, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -802,7 +867,8 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("\n" + str(table)) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) + logging.info("Error-table on TEST:\n" + str(table_test)) if rank == 0: # Save entire model @@ -821,7 +887,9 @@ def run(args: argparse.Namespace) -> None: ), } if swa_eval: - torch.save(model, Path(args.model_dir) / (args.name + "_stagetwo.model")) + torch.save( + model, Path(args.model_dir) / (args.name + "_stagetwo.model") + ) try: path_complied = Path(args.model_dir) / ( args.name + "_stagetwo_compiled.model" diff --git a/mace/data/utils.py b/mace/data/utils.py index c870d6ed..78e3e76f 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -53,7 +53,7 @@ class Configuration: def random_train_valid_split( - items: Sequence, valid_fraction: float, seed: int + items: Sequence, valid_fraction: float, seed: int, work_dir: str ) -> Tuple[List, List]: assert 0.0 < valid_fraction < 1.0 @@ -63,6 +63,19 @@ def random_train_valid_split( indices = list(range(size)) rng = np.random.default_rng(seed) rng.shuffle(indices) + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) return ( [items[i] for i in indices[:train_size]], @@ -72,12 +85,12 @@ def random_train_valid_split( def config_from_atoms_list( atoms_list: List[ase.Atoms], - energy_key="energy", - forces_key="forces", - stress_key="stress", - virials_key="virials", - dipole_key="dipole", - charges_key="charges", + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", config_type_weights: Dict[str, float] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -103,12 +116,12 @@ def config_from_atoms_list( def config_from_atoms( atoms: ase.Atoms, - energy_key="energy", - forces_key="forces", - stress_key="stress", - virials_key="virials", - dipole_key="dipole", - charges_key="charges", + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", config_type_weights: Dict[str, float] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -192,41 +205,41 @@ def test_config_types( def load_from_xyz( file_path: str, config_type_weights: Dict, - energy_key: str = "energy", - forces_key: str = "forces", - stress_key: str = "stress", - virials_key: str = "virials", - dipole_key: str = "dipole", - charges_key: str = "charges", + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", + virials_key: str = "REF_virials", + dipole_key: str = "REF_dipole", + charges_key: str = "REF_charges", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") if energy_key == "energy": - logging.info( - "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_energy'. You need to use --energy_key='REF_energy', to tell the key name chosen." + logging.warning( + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." ) energy_key = "REF_energy" for atoms in atoms_list: try: atoms.info["REF_energy"] = atoms.get_potential_energy() except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to extract energy: {e}") + logging.error(f"Failed to extract energy: {e}") atoms.info["REF_energy"] = None if forces_key == "forces": - logging.info( - "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_forces'. You need to use --forces_key='REF_forces', to tell the key name chosen." + logging.warning( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." ) forces_key = "REF_forces" for atoms in atoms_list: try: atoms.arrays["REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to extract forces: {e}") + logging.error(f"Failed to extract forces: {e}") atoms.arrays["REF_forces"] = None if stress_key == "stress": - logging.info( - "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_stress'. You need to use --stress_key='REF_stress', to tell the key name chosen." + logging.warning( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." ) stress_key = "REF_stress" for atoms in atoms_list: @@ -298,7 +311,7 @@ def compute_average_E0s( for i, z in enumerate(z_table.zs): atomic_energies_dict[z] = E0s[i] except np.linalg.LinAlgError: - logging.warning( + logging.error( "Failed to compute E0s using least squares regression, using the same for all atoms" ) atomic_energies_dict = {} diff --git a/mace/py.typed b/mace/py.typed new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/mace/py.typed @@ -0,0 +1 @@ + diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 80375590..54c59455 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -1,4 +1,5 @@ from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations @@ -39,6 +40,7 @@ "to_numpy", "to_one_hot", "build_default_arg_parser", + "check_args", "set_seeds", "init_device", "setup_logger", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 37d0ce8d..2b0e2b56 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -31,22 +31,28 @@ def build_default_arg_parser() -> argparse.ArgumentParser: # Directories parser.add_argument( - "--log_dir", help="directory for log files", type=str, default="logs" + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--log_dir", help="directory for log files", type=str, default=None ) parser.add_argument( - "--model_dir", help="directory for final model", type=str, default="." + "--model_dir", help="directory for final model", type=str, default=None ) parser.add_argument( "--checkpoints_dir", help="directory for checkpoint files", type=str, - default="checkpoints", + default=None, ) parser.add_argument( - "--results_dir", help="directory for results", type=str, default="results" + "--results_dir", help="directory for results", type=str, default=None ) parser.add_argument( - "--downloads_dir", help="directory for downloads", type=str, default="downloads" + "--downloads_dir", help="directory for downloads", type=str, default=None ) # Device and logging @@ -80,6 +86,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "PerAtomRMSE", "TotalRMSE", "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", "PerAtomMAE", "TotalMAE", "DipoleRMSE", @@ -127,7 +134,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--pair_repulsion", - help="use amsgrad variant of optimizer", + help="use pair repulsion term with ZBL potential", action="store_true", default=False, ) @@ -183,7 +190,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--hidden_irreps", help="irreps for hidden node states", type=str, - default="128x0e + 128x1o", + default=None, ) # add option to specify irreps by channel number and max L parser.add_argument( @@ -334,37 +341,37 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--energy_key", help="Key of reference energies in training xyz", type=str, - default="energy", + default="REF_energy", ) parser.add_argument( "--forces_key", help="Key of reference forces in training xyz", type=str, - default="forces", + default="REF_forces", ) parser.add_argument( "--virials_key", help="Key of reference virials in training xyz", type=str, - default="virials", + default="REF_virials", ) parser.add_argument( "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key", help="Key of reference dipoles in training xyz", type=str, - default="dipole", + default="REF_dipole", ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", type=str, - default="charges", + default="REF_charges", ) # Loss and optimization @@ -388,7 +395,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--forces_weight", help="weight of forces loss", type=float, default=100.0 ) parser.add_argument( - "--swa_forces_weight","--stage_two_forces_weight", + "--swa_forces_weight", + "--stage_two_forces_weight", help="weight of forces loss after starting Stage Two (previously called swa)", type=float, default=100.0, @@ -398,7 +406,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--energy_weight", help="weight of energy loss", type=float, default=1.0 ) parser.add_argument( - "--swa_energy_weight","--stage_two_energy_weight", + "--swa_energy_weight", + "--stage_two_energy_weight", help="weight of energy loss after starting Stage Two (previously called swa)", type=float, default=1000.0, @@ -408,7 +417,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--virials_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_virials_weight", "--stage_two_virials_weight", + "--swa_virials_weight", + "--stage_two_virials_weight", help="weight of virials loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -418,7 +428,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_stress_weight", "--stage_two_stress_weight", + "--swa_stress_weight", + "--stage_two_stress_weight", help="weight of stress loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -428,7 +439,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 ) parser.add_argument( - "--swa_dipole_weight","--stage_two_dipole_weight", + "--swa_dipole_weight", + "--stage_two_dipole_weight", help="weight of dipoles after starting Stage Two (previously called swa)", type=float, default=1.0, @@ -467,7 +479,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--lr", help="Learning rate of optimizer", type=float, default=0.01 ) parser.add_argument( - "--swa_lr", "--stage_two_lr", help="Learning rate of optimizer in Stage Two (previously called swa)", type=float, default=1e-3, dest="swa_lr" + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", ) parser.add_argument( "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 @@ -494,14 +511,16 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=0.9993, ) parser.add_argument( - "--swa", "--stage_two", + "--swa", + "--stage_two", help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", action="store_true", default=False, dest="swa", ) parser.add_argument( - "--start_swa","--start_stage_two", + "--start_swa", + "--start_stage_two", help="Number of epochs before changing to Stage Two loss weights", type=int, default=None, @@ -541,7 +560,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=True, ) parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=2 + "--eval_interval", help="evaluate model every epochs", type=int, default=1 ) parser.add_argument( "--keep_checkpoints", @@ -681,37 +700,37 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: "--energy_key", help="Key of reference energies in training xyz", type=str, - default="energy", + default="REF_energy", ) parser.add_argument( "--forces_key", help="Key of reference forces in training xyz", type=str, - default="forces", + default="REF_forces", ) parser.add_argument( "--virials_key", help="Key of reference virials in training xyz", type=str, - default="virials", + default="REF_virials", ) parser.add_argument( "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key", help="Key of reference dipoles in training xyz", type=str, - default="dipole", + default="REF_dipole", ) parser.add_argument( "--charges_key", help="Key of atomic charges in training xyz", type=str, - default="charges", + default="REF_charges", ) parser.add_argument( "--atomic_numbers", diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py new file mode 100644 index 00000000..da64806a --- /dev/null +++ b/mace/tools/arg_parser_tools.py @@ -0,0 +1,113 @@ +import logging +import os + +from e3nn import o3 + + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + # Loss and optimization + # Check Stage Two loss start + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, + ) + ) + args.swa = False + + return args, log_messages diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index cb0a0bc8..a353b447 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -29,6 +29,7 @@ class SubsetCollection: def get_dataset_from_xyz( + work_dir: str, train_path: str, valid_path: str, valid_fraction: float, @@ -36,9 +37,9 @@ def get_dataset_from_xyz( test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, - energy_key: str = "energy", - forces_key: str = "forces", - stress_key: str = "stress", + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", @@ -57,7 +58,7 @@ def get_dataset_from_xyz( keep_isolated_atoms=keep_isolated_atoms, ) logging.info( - f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" + f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" ) if valid_path is not None: _, valid_configs = data.load_from_xyz( @@ -72,15 +73,15 @@ def get_dataset_from_xyz( extract_atomic_energies=False, ) logging.info( - f"Loaded {len(valid_configs)} validation configurations from '{valid_path}'" + f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" ) train_configs = all_train_configs else: - logging.info( - "Using random %s%% of training set for validation", 100 * valid_fraction - ) train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed + all_train_configs, valid_fraction, seed, work_dir + ) + logging.info( + f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" ) test_configs = [] @@ -99,8 +100,13 @@ def get_dataset_from_xyz( # create list of tuples (config_type, list(Atoms)) test_configs = data.test_config_types(all_test_configs) logging.info( - f"Loaded {len(all_test_configs)} test configurations from '{test_path}'" + f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" ) + for name, tmp_configs in test_configs: + logging.info( + f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" + ) + return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), atomic_energies_dict, @@ -128,10 +134,10 @@ def print_git_commit(): repo = git.Repo(search_parent_directories=True) commit = repo.head.commit.hexsha - logging.info(f"Current Git commit: {commit}") + logging.debug(f"Current Git commit: {commit}") return commit except Exception as e: # pylint: disable=W0703 - logging.info(f"Error accessing Git repository: {e}") + logging.debug(f"Error accessing Git repository: {e}") return "None" @@ -284,7 +290,7 @@ def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: def get_atomic_energies(E0s, train_collection, z_table) -> dict: if E0s is not None: logging.info( - "Atomic Energies not in training file, using command line argument E0s" + "Isolated Atomic Energies (E0s) not in training file, using command line argument" ) if E0s.lower() == "average": logging.info( @@ -301,11 +307,18 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: f"Could not compute average E0s if no training xyz given, error {e} occured" ) from e else: - try: - atomic_energies_dict = ast.literal_eval(E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) + else: + try: + atomic_energies_dict = ast.literal_eval(E0s) + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e else: raise RuntimeError( "E0s not found in training file and not specified in command line" @@ -451,6 +464,14 @@ def create_error_table( "relative F RMSE %", "RMSE Stress (Virials) / meV / A (A^3)", ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] elif table_type == "TotalMAE": table.field_names = [ "config_type", @@ -515,18 +536,18 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", ] ) elif table_type == "PerAtomRMSE": table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", ] ) elif ( @@ -536,10 +557,10 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", - f"{metrics['rmse_stress'] * 1000:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", ] ) elif ( @@ -549,55 +570,81 @@ def create_error_table( table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.2f}", - f"{metrics['rmse_virials'] * 1000:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", ] ) elif table_type == "TotalMAE": table.add_row( [ name, - f"{metrics['mae_e'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", ] ) elif table_type == "PerAtomMAE": table.add_row( [ name, - f"{metrics['mae_e_per_atom'] * 1000:.1f}", - f"{metrics['mae_f'] * 1000:.1f}", - f"{metrics['rel_mae_f']:.2f}", + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", ] ) elif table_type == "DipoleRMSE": table.add_row( [ name, - f"{metrics['rmse_mu_per_atom'] * 1000:.2f}", - f"{metrics['rel_rmse_mu']:.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", ] ) elif table_type == "DipoleMAE": table.add_row( [ name, - f"{metrics['mae_mu_per_atom'] * 1000:.2f}", - f"{metrics['rel_mae_mu']:.1f}", + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", ] ) elif table_type == "EnergyDipoleRMSE": table.add_row( [ name, - f"{metrics['rmse_e_per_atom'] * 1000:.1f}", - f"{metrics['rmse_f'] * 1000:.1f}", - f"{metrics['rel_rmse_f']:.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:.1f}", - f"{metrics['rel_rmse_mu']:.1f}", + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", ] ) return table diff --git a/mace/tools/train.py b/mace/tools/train.py index 575fb02d..b38bce16 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -45,11 +45,15 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) + if epoch is None: + inintial_phrase = "Initial" + else: + inintial_phrase = f"Epoch {epoch}" if log_errors == "PerAtomRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -59,7 +63,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -69,37 +73,57 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_stress_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_stress = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_virials_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_virials = eval_metrics["mae_virials"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" + f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", ) @@ -139,7 +163,11 @@ def train( if max_grad_norm is not None: logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") - logging.info("Started training") + + logging.info("") + logging.info("===========TRAINING===========") + logging.info("Started training, reporting errors on validation set") + logging.info("Loss metrics on validation set") epoch = start_epoch # # log validation loss before _any_ training diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 65190108..762d9880 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -52,27 +52,42 @@ def setup_logger( directory: Optional[str] = None, rank: Optional[int] = 0, ): + # Create a logger logger = logging.getLogger() - logger.setLevel(level) + logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels + # Create formatters formatter = logging.Formatter( "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) + # Add filter for rank + logger.addFilter(lambda _: rank == 0) + + # Create console handler ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(level) ch.setFormatter(formatter) logger.addHandler(ch) - logger.addFilter(lambda _: (rank == 0)) - - if (directory is not None) and (tag is not None): + if directory is not None and tag is not None: os.makedirs(name=directory, exist_ok=True) - path = os.path.join(directory, tag + ".log") - fh = logging.FileHandler(path) - fh.setFormatter(formatter) - logger.addHandler(fh) + # Create file handler for non-debug logs + main_log_path = os.path.join(directory, f"{tag}.log") + fh_main = logging.FileHandler(main_log_path) + fh_main.setLevel(level) + fh_main.setFormatter(formatter) + logger.addHandler(fh_main) + + # Create file handler for debug logs + debug_log_path = os.path.join(directory, f"{tag}_debug.log") + fh_debug = logging.FileHandler(debug_log_path) + fh_debug.setLevel(logging.DEBUG) + fh_debug.setFormatter(formatter) + fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) + logger.addHandler(fh_debug) class AtomicNumberTable: diff --git a/tests/test_calculator.py b/tests/test_calculator.py index e0763a42..ef048854 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -75,6 +75,7 @@ def trained_model_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -137,6 +138,7 @@ def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -200,6 +202,7 @@ def trained_dipole_fixture(tmp_path_factory, fitting_configs): "stress_key": "", "dipole_key": "REF_dipole", "error_table": "DipoleRMSE", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -265,6 +268,7 @@ def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): "stress_key": "", "dipole_key": "REF_dipole", "error_table": "EnergyDipoleRMSE", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp("run_") @@ -332,6 +336,7 @@ def trained_committee_fixture(tmp_path_factory, fitting_configs): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, } tmp_path = tmp_path_factory.mktemp(f"run{seed}_") diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 2267d845..5f39805b 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -66,6 +66,7 @@ def fixture_fitting_configs(): "energy_key": "REF_energy", "forces_key": "REF_forces", "stress_key": "REF_stress", + "eval_interval": 2, }