From 921f63df61cd2ac152771782f99a49abfd68d72f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:16:45 +0000 Subject: [PATCH] add option to rescale number of ft sample --- mace/cli/run_train.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e8319ac7..1c0898b7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -335,8 +335,21 @@ def run(args: argparse.Namespace) -> None: ) head_config_pt.collections = collections head_configs.append(head_config_pt) + + ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) + if ratio_pt_ft < 0.1: + logging.warning( + f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " + f"increasing the number of configurations in the pt_train_file by a factor of {int(0.1 / ratio_pt_ft)}" + ) + for head_config in head_configs: + if head_config.head_name == "pt_head": + continue + head_config.collections.train += head_config.collections.train * int( + 0.1 / ratio_pt_ft + ) logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" ) # Atomic number table