diff --git a/mace/cli/preprocess_data_mpi.py b/mace/cli/preprocess_data_mpi.py
index 80d43990..9fef427e 100644
--- a/mace/cli/preprocess_data_mpi.py
+++ b/mace/cli/preprocess_data_mpi.py
@@ -8,27 +8,32 @@
 import multiprocessing as mp
 import os
 import random
+from collections.abc import Iterable
 from glob import glob
-from typing import List, Tuple
+from itertools import islice
+from typing import Generator, List, Tuple
 
 import h5py
 import numpy as np
 import tqdm
+from ase import Atoms
+from ase.io import iread
+from mpi4py import MPI
 
 from mace import data, tools
-from mace.data.utils import save_configurations_as_HDF5, config_from_atoms, Configurations
+from mace.data.utils import (
+    Configurations,
+    config_from_atoms,
+    save_configurations_as_HDF5,
+)
 from mace.modules import compute_statistics
 from mace.tools import torch_geometric
-from mace.tools.scripts_utils import get_atomic_energies, SubsetCollection
+from mace.tools.scripts_utils import SubsetCollection, get_atomic_energies
 from mace.tools.utils import AtomicNumberTable
-from ase.io import iread
-from ase import Atoms
-
-from mpi4py import MPI
-from itertools import islice
-from collections.abc import Iterable
-from typing import Generator
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+size = comm.Get_size()
 
 
 def compute_stats_target(
@@ -118,7 +123,8 @@ def get_prime_factors(n: int):
 #             break
 #         yield chunk
 
-def chunkify(iterable: Iterable, num_chunks: int) -> Generator[List, None, None]:
+
+def chunkify(data: list, size: int) -> list:
     """
     Splits an iterable into a specified number of chunks without knowing its length.
@@ -127,21 +133,25 @@ def chunkify(iterable: Iterable, num_chunks: int) -> Generator[List, None, None]
         num_chunks (int): The number of chunks to create.
 
     Yields:
-        Generator[List]: A generator yielding chunks of the input iterable.
+        Generator: A generator yielding chunks of the input iterable.
     """
-    it = iter(iterable)
-    while True:
-        chunk = list(islice(it, num_chunks))
-        if not chunk:
-            break
-        yield chunk
+    # def create_chunks(data, size):
+    total_items = len(data)
+    chunk_size = total_items // size
+    remainder = total_items % size
+
+    chunks = []
+    for i in range(size):
+        start_index = i * chunk_size + min(i, remainder)
+        end_index = start_index + chunk_size + (1 if i < remainder else 0)
+        chunks.append(data[start_index:end_index])
+
+    return chunks
 
 
 def process_chunk(
-    chunk: List[Atoms],
-    config_type_weights: dict,
-    args: argparse.Namespace
-    ) -> Configurations:
+    chunk: List[Atoms], config_type_weights: dict, args: argparse.Namespace
+) -> Configurations:
     """
     Processes a single chunk of data.
     Modify this function to include the actual processing logic.
@@ -169,22 +179,18 @@ def process_chunk(
     )
     return atoms_list
 
-def read_configs(
-    file: str,
-    args: argparse.Namespace,
-    config_type_weights: dict
-    ) -> Configurations:
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    size = comm.Get_size()
 
+def read_configs(
+    file: str, args: argparse.Namespace, config_type_weights: dict
+) -> Configurations | None:
     # MPI: Read the xyz file
     if rank == 0:
         logging.info(f"Reading the xyz file {file}...")
-        iter_atoms = iread(file, ":")
-        chunks = list(chunkify(iter_atoms, size))
+
+        data = list(iread(file, ":"))
+        chunks = chunkify(data, size)
     else:
         chunks = None
 
@@ -193,20 +199,21 @@ def read_configs(
 
     # Each process processes its chunk
     result = process_chunk(chunk, config_type_weights, args)
-    print(f"Process {rank} processed chunk {chunk}")
+    print(f"Process {rank} processed chunk of size {len(chunk)}")
 
     # Gather results from all processes
     configs = comm.gather(result, root=0)
     if rank == 0:
         logging.info("Gathered configurations from all processes")
+        return configs
+    else:
+        return
 
-    return configs
 
 
 def get_z_table(
-    configs: Configurations,
-    args: argparse.Namespace
-    ) -> AtomicNumberTable:
+    configs: Configurations, args: argparse.Namespace
+) -> AtomicNumberTable | None:
     """
     Extracts the atomic numbers from the configurations and creates an atomic number table.
@@ -217,18 +224,14 @@ def get_z_table(
         AtomicNumberTable: The atomic number table.
     """
 
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    size = comm.Get_size()
-
     if args.atomic_numbers is None:
         # MPI: Extract all the atomic numbers
         if rank == 0:
             logging.info("Extracting atomic numbers...")
-            chunks = list(chunkify(configs, size))
+            chunks = chunkify(configs, size)
         else:
             chunks = None
-
+
         chunk = comm.scatter(chunks, root=0)
 
         z_set = set()
@@ -236,16 +239,19 @@ def get_z_table(
             for z in zs:
                 z_set.add(z)
 
-        print(f"Process {rank} processed chunk {chunk}")
+        print(f"Process {rank} processed chunk of size {len(chunk)}")
 
         z_sets = comm.gather(z_set, root=0)
 
-        z_table = tools.get_atomic_number_table_from_zs(
-            z for z_set in z_sets for z in z_set
-        )
-
         if rank == 0:
+            z_table = tools.get_atomic_number_table_from_zs(
+                z for z_set in z_sets for z in z_set
+            )
             logging.info("Extracted atomic numbers")
+
+        comm.bcast(z_table, root=0)
+
+        comm.barrier()
     else:
         logging.info("Using atomic numbers from command line argument")
         zs_list = ast.literal_eval(args.atomic_numbers)
@@ -254,6 +260,7 @@ def get_z_table(
     return z_table
 
 
+
 # def write_hdf5(
 #     configs: Configurations,
 #     filename: str
@@ -307,145 +314,171 @@ def main() -> None:
         if not os.path.exists(args.h5_prefix + sub_dir):
             os.makedirs(args.h5_prefix + sub_dir)
 
-
-
     # TODO: MPI
     # - iread the xyz file
     # - save hdf5 files for each process
    # - extract all the atomic numbers
-    # - compute statistics
+    # - compute statistics
+
+    comm.barrier()
 
     train_configs = read_configs(args.train_file, args, config_type_weights)
 
+    comm.barrier()
+
     if args.valid_file is not None:
         valid_configs = read_configs(args.valid_file, args, config_type_weights)
         logging.info(f"Number of training configurations: {len(train_configs)}")
         logging.info(f"Number of validation configurations: {len(valid_configs)}")
     else:
-        logging.info(f"Using {args.valid_fraction*100}% of the training data for validation")
-        train_configs, valid_configs = data.random_train_valid_split(
-            train_configs, args.valid_fraction, seed=args.seed
-        )
-        logging.info(f"Number of training configurations: {len(train_configs)}")
-        logging.info(f"Number of validation configurations: {len(valid_configs)}")
+        if rank == 0:
+            logging.info(
+                f"Using {args.valid_fraction*100}% of the training data for validation"
+            )
+
+            train_configs, valid_configs = data.random_train_valid_split(
+                train_configs, args.valid_fraction, seed=args.seed
+            )
+            logging.info(f"Number of training configurations: {len(train_configs)}")
+            logging.info(f"Number of validation configurations: {len(valid_configs)}")
+
+    comm.barrier()
 
     if args.test_file is not None:
         test_configs = read_configs(args.test_file, args, config_type_weights)
-        test_configs = data.test_config_types(test_configs)
-        logging.info(f"Number of test configurations: {len(test_configs)}")
+
+        if rank == 0:
+            test_configs = data.test_config_types(test_configs)
+            logging.info(f"Number of test configurations: {len(test_configs)}")
     else:
         test_configs = []
         logging.info("No test set provided")
 
-    # Atomic number table
+    comm.barrier()
 
-    configs = train_configs + valid_configs
+    # Atomic number table
+    assert isinstance(train_configs, list)
+    assert isinstance(valid_configs, list)
+    configs = [c for c in train_configs] + [c for c in valid_configs]
     z_table = get_z_table(configs, args)
 
+    comm.barrier()
 
-    collections = SubsetCollection(
-        train=train_configs,
-        valid=valid_configs,
-        tests=test_configs
-    )
-
-
+    if rank == 0:
 
-    if args.shuffle:
-        random.shuffle(train_configs)
+        collections = SubsetCollection(
+            train=train_configs, valid=valid_configs, tests=test_configs
+        )
+        if args.shuffle:
+            random.shuffle(collections.train)
 
-    # split collections.train into batches and save them to hdf5
-    split_train = np.array_split(collections.train, args.num_process)
-    drop_last = False
-    if len(collections.train) % 2 == 1:
-        drop_last = True
+        # split collections.train into batches and save them to hdf5
+        split_train = np.array_split(collections.train, args.num_process)
+        drop_last = False
+        if len(collections.train) % 2 == 1:
+            drop_last = True
 
-    # Define Task for Multiprocessiing
-    def multi_train_hdf5(process):
-        with h5py.File(args.h5_prefix + "train/train_" + str(process)+".h5", "w") as f:
-            f.attrs["drop_last"] = drop_last
-            save_configurations_as_HDF5(split_train[process], process, f)
-
-    processes = []
-    for i in range(args.num_process):
-        p = mp.Process(target=multi_train_hdf5, args=[i])
-        p.start()
-        processes.append(p)
+        # Define Task for Multiprocessiing
+        def multi_train_hdf5(process):
+            with h5py.File(
+                args.h5_prefix + "train/train_" + str(process) + ".h5", "w"
+            ) as f:
+                f.attrs["drop_last"] = drop_last
+                save_configurations_as_HDF5(split_train[process], process, f)
 
-    for i in processes:
-        i.join()
+        processes = []
+        for i in range(args.num_process):
+            p = mp.Process(target=multi_train_hdf5, args=[i])
+            p.start()
+            processes.append(p)
+        for i in processes:
+            i.join()
 
-    logging.info("Computing statistics")
-    atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
-    atomic_energies: np.ndarray = np.array(
-        [atomic_energies_dict[z] for z in z_table.zs]
-    )
-    logging.info(f"Atomic energies: {atomic_energies.tolist()}")
-    _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process]
-    avg_num_neighbors, mean, std=pool_compute_stats(_inputs)
-    logging.info(f"Average number of neighbors: {avg_num_neighbors}")
-    logging.info(f"Mean: {mean}")
-    logging.info(f"Standard deviation: {std}")
-
-    # save the statistics as a json
-    statistics = {
-        "atomic_energies": str(atomic_energies_dict),
-        "avg_num_neighbors": avg_num_neighbors,
-        "mean": mean,
-        "std": std,
-        "atomic_numbers": str(z_table.zs),
-        "r_max": args.r_max,
-    }
-
-    with open(args.h5_prefix + "statistics.json", "w") as f:  # pylint: disable=W1514
-        json.dump(statistics, f)
-
-    logging.info("Preparing validation set")
-    if args.shuffle:
-        random.shuffle(collections.valid)
-    split_valid = np.array_split(collections.valid, args.num_process)
-    drop_last = False
-    if len(collections.valid) % 2 == 1:
-        drop_last = True
+        logging.info("Computing statistics")
+        atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table)
+        atomic_energies: np.ndarray = np.array(
+            [atomic_energies_dict[z] for z in z_table.zs]
+        )
+        logging.info(f"Atomic energies: {atomic_energies.tolist()}")
+        _inputs = [
+            args.h5_prefix + "train",
+            z_table,
+            args.r_max,
+            atomic_energies,
+            args.batch_size,
+            args.num_process,
+        ]
+        avg_num_neighbors, mean, std = pool_compute_stats(_inputs)
+        logging.info(f"Average number of neighbors: {avg_num_neighbors}")
+        logging.info(f"Mean: {mean}")
+        logging.info(f"Standard deviation: {std}")
+
+        # save the statistics as a json
+        statistics = {
+            "atomic_energies": str(atomic_energies_dict),
+            "avg_num_neighbors": avg_num_neighbors,
+            "mean": mean,
+            "std": std,
+            "atomic_numbers": str(z_table.zs),
+            "r_max": args.r_max,
+        }
+
+        with open(
+            args.h5_prefix + "statistics.json", "w"
+        ) as f:  # pylint: disable=W1514
+            json.dump(statistics, f)
+
+        logging.info("Preparing validation set")
+        if args.shuffle:
+            random.shuffle(collections.valid)
+        split_valid = np.array_split(collections.valid, args.num_process)
+        drop_last = False
+        if len(collections.valid) % 2 == 1:
+            drop_last = True
+
+        def multi_valid_hdf5(process):
+            with h5py.File(
+                args.h5_prefix + "val/val_" + str(process) + ".h5", "w"
+            ) as f:
+                f.attrs["drop_last"] = drop_last
+                save_configurations_as_HDF5(split_valid[process], process, f)
 
-    def multi_valid_hdf5(process):
-        with h5py.File(args.h5_prefix + "val/val_" + str(process)+".h5", "w") as f:
-            f.attrs["drop_last"] = drop_last
-            save_configurations_as_HDF5(split_valid[process], process, f)
+        processes = []
+        for i in range(args.num_process):
+            p = mp.Process(target=multi_valid_hdf5, args=[i])
+            p.start()
+            processes.append(p)
 
-    processes = []
-    for i in range(args.num_process):
-        p = mp.Process(target=multi_valid_hdf5, args=[i])
-        p.start()
-        processes.append(p)
+        for i in processes:
+            i.join()
 
-    for i in processes:
-        i.join()
+        if args.test_file is not None:
 
-    if args.test_file is not None:
-        def multi_test_hdf5(process, name):
-            with h5py.File(args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w") as f:
-                f.attrs["drop_last"] = drop_last
-                save_configurations_as_HDF5(split_test[process], process, f)
+            def multi_test_hdf5(process, name):
+                with h5py.File(
+                    args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w"
+                ) as f:
+                    f.attrs["drop_last"] = drop_last
+                    save_configurations_as_HDF5(split_test[process], process, f)
 
-        logging.info("Preparing test sets")
-        for name, subset in collections.tests:
-            drop_last = False
-            if len(subset) % 2 == 1:
-                drop_last = True
-            split_test = np.array_split(subset, args.num_process)
+            logging.info("Preparing test sets")
+            for name, subset in collections.tests:
+                drop_last = False
+                if len(subset) % 2 == 1:
+                    drop_last = True
+                split_test = np.array_split(subset, args.num_process)
 
-            processes = []
-            for i in range(args.num_process):
-                p = mp.Process(target=multi_test_hdf5, args=[i, name])
-                p.start()
-                processes.append(p)
+            processes = []
+            for i in range(args.num_process):
+                p = mp.Process(target=multi_test_hdf5, args=[i, name])
+                p.start()
+                processes.append(p)
 
-            for i in processes:
-                i.join()
+            for i in processes:
+                i.join()
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
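Note on the broadcast added in get_z_table: mpi4py's comm.bcast returns the broadcast object rather than filling its argument in place, and in this patch z_table is only bound on rank 0, so the bare comm.bcast(z_table, root=0) references an unbound name on the other ranks and its return value is discarded before the later return z_table. A minimal, self-contained sketch of the usual assign-the-result pattern (the atomic numbers below are illustrative placeholders, not values from this patch):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Rank 0 computes the object; every other rank starts from None.
    zs = sorted({1, 6, 8}) if rank == 0 else None

    # comm.bcast returns the object on every rank, so the result must be
    # assigned back; otherwise non-root ranks keep None.
    zs = comm.bcast(zs, root=0)

    print(f"rank {rank} received atomic numbers {zs}")

Run it with, for example, mpirun -np 4 python bcast_sketch.py (bcast_sketch.py is just a placeholder filename for the snippet); every rank should print the same list.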