From 757385a39afa5cf61b0a3e73523e3ff73e748874 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:52:41 +0100 Subject: [PATCH] automatic download of the descriptors --- mace/cli/fine_tuning_select.py | 29 +++++++++++++++-------------- mace/cli/run_train.py | 26 +++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 5fe1f7d0..2cbbb522 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -98,7 +98,7 @@ def parse_args() -> argparse.Namespace: def calculate_descriptors( atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict ) -> None: - print("Calculating descriptors") + logging.info("Calculating descriptors") for mol in atoms: descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) # average descriptors over atoms for each element @@ -164,8 +164,8 @@ def run( """ Run the farthest point sampling algorithm. """ - print(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) - print("n_samples", self.n_samples) + logging.info(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) + logging.info("n_samples", self.n_samples) self.list_index = fpsample.fps_npdu_kdtree_sampling( self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples ) @@ -207,12 +207,12 @@ def select_samples( if args.filtering_type != None: all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) - print( - "Filtering configurations based on the finetuning set," + logging.info( + "Filtering configurations based on the finetuning set, " f"filtering type: combinations, elements: {all_species_ft}" ) if args.descriptors is not None: - print("Loading descriptors") + logging.info("Loading descriptors") descriptors = np.load(args.descriptors, allow_pickle=True) atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): @@ -222,7 +222,6 @@ def select_samples( for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] - else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") atoms_list_pt = [ @@ -233,7 +232,7 @@ def select_samples( else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") if args.descriptors is not None: - print( + logging.info( "Loading descriptors for the pretraining set from {}".format( args.descriptors ) @@ -244,35 +243,37 @@ def select_samples( if args.num_samples is not None and args.num_samples < len(atoms_list_pt): if args.descriptors is None: - print("Calculating descriptors for the pretraining set") + logging.info("Calculating descriptors for the pretraining set") calculate_descriptors(atoms_list_pt, calc, None) descriptors_list = [ atoms.info["mace_descriptors"] for atoms in atoms_list_pt ] - print( + logging.info( "Saving descriptors at {}".format( args.output.replace(".xyz", "descriptors.npy") ) ) np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list) - print("Selecting configurations using Farthest Point Sampling") + logging.info("Selecting configurations using Farthest Point Sampling") fps_pt = FPS(atoms_list_pt, args.num_samples) idx_pt = fps_pt.run() - print(f"Selected {len(idx_pt)} configurations") + logging.info(f"Selected {len(idx_pt)} configurations") atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] atoms.info["pretrained"] = True atoms.info["config_weight"] = args.weight_pt + atoms.info["mace_descriptors"] = None if args.head_pt is not None: atoms.info["head"] = args.head_pt - print("Saving the selected configurations") + logging.info("Saving the selected configurations") ase.io.write(args.output, atoms_list_pt, format="extxyz") - print("Saving a combined XYZ file") + logging.info("Saving a combined XYZ file") for atoms in atoms_list_ft: atoms.info["pretrained"] = False atoms.info["config_weight"] = args.weight_ft + atoms.info["mace_descriptors"] = None if args.head_ft is not None: atoms.info["head"] = args.head_ft atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7c1d84df..b509ea1c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -186,6 +186,7 @@ def main() -> None: logging.info(f"Using heads: {heads}") try: checkpoint_url = "https://tinyurl.com/mw2wetc5" + descriptors_url = "https://tinyurl.com/mpe7br4d" cache_dir = os.path.expanduser("~/.cache/mace") checkpoint_url_name = "".join( c @@ -193,6 +194,12 @@ def main() -> None: if c.isalnum() or c in "_" ) cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c + for c in os.path.basename(descriptors_url) + if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" if not os.path.isfile(cached_dataset_path): os.makedirs(cache_dir, exist_ok=True) # download and save to disk @@ -205,9 +212,26 @@ def main() -> None: f"Dataset download failed, please check the URL {checkpoint_url}" ) logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info( + f"Materials Project descriptors to {cached_descriptors_path}" + ) dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path msg = f"Using Materials Project dataset with {dataset_mp}" logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) args_samples = { "configs_pt": dataset_mp, "configs_ft": args.train_file, @@ -220,7 +244,7 @@ def main() -> None: "weight_ft": 1.0, "filtering_type": "combination", "output": f"{cache_dir}/mp_finetuning.xyz", - "descriptors": r"D:\Work\mace_mp\descriptors.npy", + "descriptors": descriptors_mp, "device": args.device, "default_dtype": args.default_dtype, }