From 544bdaeced6f64d71f22d533f9599ae21b6d5711 Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Wed, 26 Jun 2024 15:33:06 -0500 Subject: [PATCH] KIM Trainer and tests --- kliff/dataset/dataset.py | 10 +- kliff/dataset/weight.py | 16 +- kliff/models/kim.py | 62 +++-- kliff/trainer/__init__.py | 2 +- kliff/trainer/_kim_loss_functions.py | 17 ++ kliff/trainer/base_trainer.py | 56 +++-- kliff/trainer/kim_trainer.py | 214 ++++++++++++++++++ kliff/trainer/lightning_trainer.py | 22 +- .../trainer_data/example_config_ase_kim.yaml | 59 +++++ .../example_config_ase_lightning_gnn.yaml | 2 - tests/trainer/test_kim_trainer.py | 120 ++++++++++ tests/trainer/test_lightning_trainer.py | 3 - 12 files changed, 512 insertions(+), 71 deletions(-) create mode 100644 kliff/trainer/_kim_loss_functions.py create mode 100644 kliff/trainer/kim_trainer.py create mode 100644 tests/test_data/trainer_data/example_config_ase_kim.yaml create mode 100644 tests/trainer/test_kim_trainer.py diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index cfcf8353..4b95a9df 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -1092,7 +1092,7 @@ def add_weights(configurations: List[Configuration], path: Union[Path, str]): """ if path is None: - logger.info("No weights provided.") + logger.info("No explicit weights datafile provided.") return weights_data = np.genfromtxt(path, names=True) @@ -1293,10 +1293,10 @@ def get_dataset_from_manifest(dataset_manifest: dict) -> "Dataset": weights = Path(weights) elif isinstance(weights, dict): weights = Weight( - config_weight=weights.get("config", 0.0), - energy_weight=weights.get("energy", 0.0), - forces_weight=weights.get("forces", 0.0), - stress_weight=weights.get("stress", 0.0), + config_weight=weights.get("config", None), + energy_weight=weights.get("energy", None), + forces_weight=weights.get("forces", None), + stress_weight=weights.get("stress", None), ) else: raise DatasetError("Weights must be a path or a dictionary.") diff --git a/kliff/dataset/weight.py b/kliff/dataset/weight.py index d775c29f..da235617 100644 --- a/kliff/dataset/weight.py +++ b/kliff/dataset/weight.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import numpy as np from loguru import logger @@ -27,10 +27,10 @@ class Weight: def __init__( self, - config_weight: float = 1.0, - energy_weight: float = 1.0, - forces_weight: float = 1.0, - stress_weight: float = 1.0, + config_weight: Union[float, None] = 1.0, + energy_weight: Union[float, None] = 1.0, + forces_weight: Union[float, None] = 1.0, + stress_weight: Union[float, None] = 1.0, ): self._config_weight = config_weight self._energy_weight = energy_weight @@ -91,11 +91,11 @@ def _check_compute_flag(self, config): # If the weight are really small, but not zero, then warn the user. Zero weight # usually means that the property is used. 
- if config._energy is not None and np.all(ew < 1e-12): + if config._energy is not None and ew is not None and np.all(ew < 1e-12): logger.warning(msg.format("energy", ew)) - if config._forces is not None and np.all(fw < 1e-12): + if config._forces is not None and fw is not None and np.all(fw < 1e-12): logger.warning(msg.format("forces", fw)) - if config._stress is not None and np.all(sw < 1e-12): + if config._stress is not None and sw is not None and np.all(sw < 1e-12): logger.warning(msg.format("stress", sw)) diff --git a/kliff/models/kim.py b/kliff/models/kim.py index be73d268..53a0cd76 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -622,10 +622,10 @@ def write_kim_model(self, path: Path = None): Write out a KIM model that can be used directly with the kim-api. This function typically write two files to `path`: (1) CMakeLists.txt, and (2) - a parameter file like A.model_params. `path` will be created if it does not exist. + a parameter file like A.params. `path` will be created if it does not exist. Args: - path: Path to the a directory to store the model. If `None`, it is set to + path: Path to a directory to store the model. If `None`, it is set to `./MODEL_NAME_kliff_trained`, where `MODEL_NAME` is the `model_name` that provided at the initialization of this class. @@ -696,7 +696,11 @@ def __call__( return kim_ca_instance.results @staticmethod - def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): + def get_model_from_manifest( + model_manifest: dict, + param_manifest: dict = None, + is_model_tarfile: bool = False, + ): """ Get the model from a configuration. If it is a valid KIM model, it will return the KIMModel object. If it is a TorchML model, it will return the torch @@ -710,7 +714,6 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): Example `model_manifest`: ```yaml model: - model_type: kim # kim or torch model_path: ./model.tar.gz # path to the model tarball model_name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing model_collection: "user" @@ -735,12 +738,12 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): Args: model_manifest: configuration object param_manifest: parameter transformation configuration + is_model_tarfile: whether the model is a tarball Returns: Model object """ model_name: Union[None, str] = model_manifest.get("name", None) - model_type: Union[None, str] = model_manifest.get("type", None) model_path: Union[None, str, Path] = model_manifest.get("path", None) model_driver = KIMModel.get_model_driver_name(model_name) model_collection = model_manifest.get("collection") @@ -754,35 +757,9 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): f"Model driver {model_driver} not supported for KIMModel training." ) - # is model a tarball? - if model_path is not None: - model_path = Path(model_path) - if model_path.suffix == ".tar": - model_type = "tar" - # ensure model is installed - if model_type.lower() == "kim": - # is it a tar file? - is_model_installed = is_kim_model_installed(model_name) - if is_model_installed: - logger.info(f"Model {model_name} is already installed, continuing ...") - else: - logger.info( - f"Model {model_name} not installed on system, attempting to installing ..." - ) - was_install_success = install_kim_model(model_name, model_collection) - if not was_install_success: - logger.error( - f"Model {model_name} not found in the KIM API collections. 
Please check the model name and try again." - ) - raise KIMModelError(f"Model {model_name} not found.") - else: - logger.info( - f"Model {model_name} installed in {model_collection} collection." - ) - - elif model_type.lower() == "tar": - archive_content = tarfile.open(model_path + "/" + model_name) + if is_model_tarfile: + archive_content = tarfile.open(model_path) model = archive_content.getnames()[0] archive_content.extractall(model_path) subprocess.run( @@ -798,8 +775,25 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): logger.info( f"Tarball Model {model} installed in {model_collection} collection." ) + + is_model_installed = is_kim_model_installed(model_name) + + if is_model_installed: + logger.info(f"Model {model_name} is already installed, continuing ...") else: - raise KIMModelError(f"Model type {model_type} not supported.") + logger.info( + f"Model {model_name} not installed on system, attempting to installing ..." + ) + was_install_success = install_kim_model(model_name, model_collection) + if not was_install_success: + logger.error( + f"Model {model_name} not found in the KIM API collections. Please check the model name and try again." + ) + raise KIMModelError(f"Model {model_name} not found.") + else: + logger.info( + f"Model {model_name} installed in {model_collection} collection." + ) model = KIMModel(model_name) diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py index 20397e25..f3b1e7b9 100644 --- a/kliff/trainer/__init__.py +++ b/kliff/trainer/__init__.py @@ -1,5 +1,5 @@ from .base_trainer import Trainer +from .kim_trainer import KIMTrainer from .lightning_trainer import GNNLightningTrainer -# from .kim_trainer import KIMTrainer # from .torch_trainer import DNNTrainer diff --git a/kliff/trainer/_kim_loss_functions.py b/kliff/trainer/_kim_loss_functions.py new file mode 100644 index 00000000..e7537e5b --- /dev/null +++ b/kliff/trainer/_kim_loss_functions.py @@ -0,0 +1,17 @@ +import numpy as np + + +def MSE_loss( + predictions: np.ndarray, + targets: np.ndarray, +) -> np.ndarray: + r""" + Compute the mean squared error (MSE) of the residuals. + + Args: + + Returns: + The MSE of the residuals. + """ + residuals = predictions - targets + return np.mean(residuals**2) diff --git a/kliff/trainer/base_trainer.py b/kliff/trainer/base_trainer.py index 70c7e354..d9c52ebc 100644 --- a/kliff/trainer/base_trainer.py +++ b/kliff/trainer/base_trainer.py @@ -16,6 +16,7 @@ from loguru import logger from kliff.dataset import Dataset +from kliff.dataset.weight import Weight if TYPE_CHECKING: from kliff.transforms.configuration_transforms import ConfigurationTransform @@ -104,9 +105,9 @@ def __init__(self, training_manifest: dict, model=None): # model variables self.model_manifest: dict = { - "type": "kim", "name": None, "path": None, + "collection": "user", } self.model: Callable = model @@ -144,7 +145,6 @@ def __init__(self, training_manifest: dict, model=None): } self.optimizer_manifest: dict = { - "provider": "scipy", "name": None, "learning_rate": None, "kwargs": None, @@ -170,6 +170,7 @@ def __init__(self, training_manifest: dict, model=None): self.export_manifest: dict = { "model_name": None, "model_path": None, + "generate_tarball": False, } # state variables @@ -426,12 +427,29 @@ def setup_dataset(self): TODO: ColabFit integration for extreme scale datasets. 
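+
+        A sketch of the ``loss.weights`` block that the code below consumes (illustrative;
+        the keys mirror the fields of :class:`~kliff.dataset.weight.Weight`, and the
+        values are taken from the example manifest):
+
+            weights = {"config": 1.0, "energy": 1.0, "forces": 1.0, "stress": None}
+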
""" dataset_module_manifest = deepcopy(self.dataset_manifest) - dataset_module_manifest["weights"] = self.loss_manifest["weights"] + # dataset_module_manifest["weights"] = self.loss_manifest["weights"] + + weights = self.loss_manifest["weights"] + + if weights is not None: + if isinstance(weights, str): + weights = Path(weights) + elif isinstance(weights, dict): + weights = Weight( + config_weight=weights.get("config", None), + energy_weight=weights.get("energy", None), + forces_weight=weights.get("forces", None), + stress_weight=weights.get("stress", None), + ) + else: + raise TrainerError("Weights must be a path or a dictionary.") + dataset_list = _parallel_read( self.dataset_manifest["path"], num_chunks=self.optimizer_manifest["num_workers"], energy_key=self.dataset_manifest.get("keys", {}).get("energy", "energy"), forces_key=self.dataset_manifest.get("keys", {}).get("forces", "forces"), + weights=weights, ) self.dataset = deepcopy(dataset_list[0]) for ds in dataset_list[1:]: @@ -637,17 +655,18 @@ def setup_dataset_split(self): fmt="%d", ) - if isinstance(val_indices, str): - self.dataset_sample_manifest["indices_files"]["val"] = val_indices - else: - self.dataset_sample_manifest["indices_files"][ - "val" - ] = f"{self.current['run_dir']}/val_indices.txt" - np.savetxt( - self.dataset_sample_manifest["indices_files"]["val"], - val_indices, - fmt="%d", - ) + if val_size > 0: + if isinstance(val_indices, str): + self.dataset_sample_manifest["indices_files"]["val"] = val_indices + else: + self.dataset_sample_manifest["indices_files"][ + "val" + ] = f"{self.current['run_dir']}/val_indices.txt" + np.savetxt( + self.dataset_sample_manifest["indices_files"]["val"], + val_indices, + fmt="%d", + ) def loss(self, *args, **kwargs): raise TrainerError("loss not implemented.") @@ -708,6 +727,7 @@ def _parallel_read( num_chunks=None, energy_key="Energy", forces_key="forces", + weights=None, ) -> List[Dataset]: """ Read and transform frames in parallel. Returns n_chunks of datasets. @@ -742,18 +762,22 @@ def _parallel_read( with multiprocessing.Pool(processes=num_chunks) as pool: ds = pool.starmap( _read_frames, - [(file_path, start, end, energy_key, forces_key) for start, end in chunks], + [ + (file_path, start, end, energy_key, forces_key, weights) + for start, end in chunks + ], ) return ds -def _read_frames(file_path, start, end, energy_key, forces_key): +def _read_frames(file_path, start, end, energy_key, forces_key, weights): ds = Dataset.from_ase( path=file_path, energy_key=energy_key, forces_key=forces_key, slices=slice(start, end), + weight=weights, ) return ds diff --git a/kliff/trainer/kim_trainer.py b/kliff/trainer/kim_trainer.py new file mode 100644 index 00000000..c2c34774 --- /dev/null +++ b/kliff/trainer/kim_trainer.py @@ -0,0 +1,214 @@ +import importlib +import tarfile +from pathlib import Path + +import numpy as np +from loguru import logger + +from kliff.models import KIMModel + +from ._kim_loss_functions import MSE_loss +from .base_trainer import Trainer, TrainerError + +SCIPY_MINIMIZE_METHODS = [ + "Nelder-Mead", + "Powell", + "CG", + "BFGS", + "Newton-CG", + "L-BFGS-B", + "TNC", + "COBYLA", + "SLSQP", + "trust-constr", + "dogleg", + "trust-ncg", + "trust-exact", + "trust-krylov", +] + + +class KIMTrainer(Trainer): + """ + This class extends the base Trainer class for training OpenKIM physics based models. + It will use the scipy optimizers. It will perform a check to exclude TorchML model + driver based models, as they would be handled by Torch based trainers. 
It can read model tarballs
+    as well as export the trained model as a tarball for ease of use. It uses the KIMModel
+    class to load the model and set its parameters, and it provides an explicit interface
+    for parameter transformations.
+
+    Args:
+        configuration (dict): The configuration dictionary.
+    """
+
+    def __init__(self, configuration: dict):
+        self.model_driver_name = None
+        self.parameters = None
+        self.mutable_parameters_list = []
+        self.use_energy = True
+        self.use_forces = False
+        self.use_stress = False
+        self.is_model_tarfile = False
+
+        super().__init__(configuration)
+
+        self.loss_function = self._get_loss_fn()
+        self.result = None
+
+    def setup_model(self):
+        """
+        Load the installed KIM model, or install it from source. If the required model
+        driver belongs to the TorchML* family, an error is raised, as such models are
+        handled by the DNNTrainer or GNNLightningTrainer.
+
+        The path can be a folder containing the model, or a tar file. The model name is
+        the KIM model name.
+        """
+        if self.model_manifest["path"]:
+            try:
+                self.is_model_tarfile = tarfile.is_tarfile(self.model_manifest["path"])
+            except (IsADirectoryError, TypeError) as e:
+                self.is_model_tarfile = False
+                logger.debug(f"Model path is not a tarfile: {e}")
+
+        # check for unsupported model drivers
+        self.model = KIMModel.get_model_from_manifest(
+            self.model_manifest, self.transform_manifest, self.is_model_tarfile
+        )
+
+        self.parameters = self.model.get_model_params()
+
+    def setup_optimizer(self):
+        """
+        Set up the optimizer based on the provided information. If the optimizer name is
+        missing, or is not one of the supported scipy minimize methods, an error is
+        raised. Optimization is delegated to :func:`scipy.optimize.minimize`.
+        """
+        if self.optimizer_manifest["name"] not in SCIPY_MINIMIZE_METHODS:
+            raise TrainerError(
+                f"Optimizer not supported: {self.optimizer_manifest['name']}."
+            )
+        optimizer_lib = importlib.import_module("scipy.optimize")
+        self.optimizer = getattr(optimizer_lib, "minimize")
+        # TODO: LM-Geodesic optimizer
+
+    def loss(self, x: np.ndarray) -> float:
+        """
+        Compute the loss for the given parameters. It sets the KIM model parameters,
+        computes the requested loss function for all trainable properties, and returns
+        the total loss after scaling each contribution with the per-configuration
+        :class:`~kliff.dataset.weight.Weight`.
+
+        TODO:
+            Include MPI support.
+
+        Args:
+            x (np.ndarray): The model parameters.
+
+        Returns:
+            float: The total loss.
+        """
+        # set the parameters
+        self.model.update_model_params(x)
+        # compute the loss
+        loss = 0.0
+        for configuration in self.train_dataset:
+            compute_energy = True if configuration.weight.energy_weight else False
+            compute_forces = True if configuration.weight.forces_weight else False
+            compute_stress = True if configuration.weight.stress_weight else False
+
+            prediction = self.model(
+                configuration,
+                compute_energy=compute_energy,
+                compute_forces=compute_forces,
+                compute_stress=compute_stress,
+            )
+
+            # accumulate the loss for this configuration, then scale it by the
+            # configuration weight before adding it to the total loss
+            config_loss = 0.0
+            if configuration.weight.energy_weight:
+                config_loss += configuration.weight.energy_weight * self.loss_function(
+                    prediction["energy"], configuration.energy
+                )
+            if configuration.weight.forces_weight:
+                config_loss += configuration.weight.forces_weight * self.loss_function(
+                    prediction["forces"], configuration.forces
+                )
+            if configuration.weight.stress_weight:
+                config_loss += configuration.weight.stress_weight * self.loss_function(
+                    prediction["stress"], configuration.stress
+                )
+            loss += configuration.weight.config_weight * config_loss
+
+        return loss
+
+    def checkpoint(self, *args, **kwargs):
+        raise TrainerError("checkpoint not implemented.")
+
+    def train_step(self, *args, **kwargs):
+        raise TrainerError("train_step not implemented.")
+
+    def validation_step(self, *args, **kwargs):
+        raise TrainerError("validation_step not implemented.")
+
+    def get_optimizer(self, *args, **kwargs):
+        raise TrainerError("get_optimizer not implemented.")
+
+    def train(self):
+        """
+        Train the model using the provided optimizer. It sets the model parameters to
+        the optimal values found by the optimizer, logs the optimization status and
+        message, and logs an error if the optimization fails.
+
+        TODO:
+            Include MPI support.
+            Log loss trajectory for KIM models.
+        """
+
+        def _wrapper_func(x):
+            return self.loss(x)
+
+        x = self.model.get_opt_params()
+        options = self.optimizer_manifest.get("kwargs", {})
+        options["options"] = {
+            "maxiter": self.optimizer_manifest["epochs"],
+            "disp": self.current["verbose"],
+        }
+        self.result = self.optimizer(
+            _wrapper_func, x, method=self.optimizer_manifest["name"], **options
+        )
+
+        if self.result.success:
+            logger.info(f"Optimization successful: {self.result.message}")
+            self.model.update_model_params(self.result.x)
+        else:
+            logger.error(f"Optimization failed: {self.result.message}")
+
+    def _get_loss_fn(self) -> callable:
+        """
+        Get the loss function based on the provided loss manifest. It raises an error
+        if the requested loss function is not supported.
+
+        Returns:
+            function: The loss function.
+        """
+        if self.loss_manifest["function"].lower() == "mse":
+            return MSE_loss
+        else:
+            raise TrainerError(
+                f"Loss function {self.loss_manifest['function']} not supported."
+            )
+
+    def save_kim_model(self):
+        """
+        Save the KIM model to the provided path. It will also generate a tarball if
+        specified in the export manifest.
+        """
+        path = (
+            Path(self.export_manifest["model_path"])
+            / self.export_manifest["model_name"]
+        )
+        self.model.write_kim_model(path)
+        if self.export_manifest["generate_tarball"]:
+            tarfile_path = path.with_suffix(".tar.gz")
+            with tarfile.open(tarfile_path, "w:gz") as tar:
+                tar.add(path, arcname=path.name)
+            logger.info(f"Model tarball saved: {tarfile_path}")
+        logger.info(f"KIM model saved at {path}")
diff --git a/kliff/trainer/lightning_trainer.py b/kliff/trainer/lightning_trainer.py
index 38045993..0cfffbe0 100644
--- a/kliff/trainer/lightning_trainer.py
+++ b/kliff/trainer/lightning_trainer.py
@@ -3,6 +3,7 @@
 # This is temporary fix till torch 1 -> 2 migration is complete
 import importlib.metadata
 import os
+import tarfile
 from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union
@@ -253,7 +254,7 @@ class GNNLightningTrainer(Trainer):
         callbacks.
     """

-    def __init__(self, manifest, model):
+    def __init__(self, manifest, model=None):
         """
         Initialize the GNNLightningTrainer.
@@ -281,6 +282,14 @@ def setup_model(self):
         with the model, and the training parameters.
         """
         # if dict has key ema, then set ema to True, decay to the dict value, else set ema false
+        if not self.model:
+            try:
+                self.model = torch.jit.load(self.model_manifest["model_path"])
+            except ValueError:
+                raise TrainerError(
+                    "No model was provided, and model_path is not a valid TorchScript model."
+                )
+
         ema = True if self.optimizer_manifest.get("ema", False) else False
         if ema:
             ema_decay = self.optimizer_manifest.get("ema_decay", 0.99)
@@ -504,6 +513,11 @@ def save_kim_model(self, path: str = "kim-model"):
         with open(f"{path}/CMakeLists.txt", "w") as f:
             f.write(cmakefile)

+        if self.export_manifest["generate_tarball"]:
+            tarball_name = f"{path}.tar.gz"
+            with tarfile.open(tarball_name, "w:gz") as tar:
+                tar.add(path, arcname=os.path.basename(path))
+            logger.info(f"Model tarball saved: {tarball_name}")
         logger.info(f"KIM model saved at {path}")

     def setup_optimizer(self):
@@ -512,4 +526,8 @@ def setup_optimizer(self):

     def seed_all(self):
         super().seed_all()
-        pl.seed_everything(self.current["seed"])
+        pl.seed_everything(self.workspace["seed"])
+
+
+# TODO: Custom loss (via torchmetrics)?
+# TODO: switch str everywhere to Path
diff --git a/tests/test_data/trainer_data/example_config_ase_kim.yaml b/tests/test_data/trainer_data/example_config_ase_kim.yaml
new file mode 100644
index 00000000..5ce814d0
--- /dev/null
+++ b/tests/test_data/trainer_data/example_config_ase_kim.yaml
@@ -0,0 +1,59 @@
+workspace:
+  name: test_run # Name of the base workspace folder, where all the runs will be stored
+  seed: 12345 # Seed for random number generator, all
+
+dataset:
+  type: ase # ase or path or colabfit
+  path: "../test_data/configs/Si_4.xyz" # Path to the dataset, ignored for colabfit
+  save: False # Save processed dataset to a file
+  keys:
+    energy: Energy # Key for energy, if ase dataset is used
+    forces: force # Key for forces, if ase dataset is used
+
+model:
+  collection: user
+  path: ./
+  name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing
+
+transforms:
+  parameter: # optional for KIM models, list of parameters to optimize
+    - A # plain entries are optimized without a transform
+    - B
+    - sigma: # a dict entry means the parameter is transformed
+        transform_name: LogParameterTransform
+        value: 2.0
+        bounds: [[1.0, 10.0]]
+
+training:
+  loss:
+    function: MSE # optional: path to loss function file?
+ weights: # optional: path to weights file + config: 1.0 + energy: 1.0 + forces: 1.0 + normalize_per_atom: true + optimizer: + name: L-BFGS-B + learning_rate: + kwargs: + tol: 1.e-6 # 1. is necessary, 1e-6 is treated as string + + training_dataset: + train_size: 3 # Number of training samples + train_indices: # files with indices [optional] + val_dataset: + val_size: 1 # Number of validation samples + val_indices: "none" # files with indices [optional] + + batch_size: 1 + epochs: 1000 # maxiter + device: cpu + num_workers: 2 + chkpt_interval: 1 + stop_condition: + verbose: False + +export: # optional: export the trained model + generate_tarball: True + model_path: ./ + model_name: SW_StillingerWeber_trained_1985_Si__MO_405512056662_006 diff --git a/tests/test_data/trainer_data/example_config_ase_lightning_gnn.yaml b/tests/test_data/trainer_data/example_config_ase_lightning_gnn.yaml index 127672f8..bf132e8f 100644 --- a/tests/test_data/trainer_data/example_config_ase_lightning_gnn.yaml +++ b/tests/test_data/trainer_data/example_config_ase_lightning_gnn.yaml @@ -17,7 +17,6 @@ dataset: # stress: virial # Key for stress, if ase dataset is used model: - type: torch # torch or tar path: ./ name: "TorchGNN" # Just a name for the model input_args: @@ -46,7 +45,6 @@ training: loss_traj: False optimizer: name: Adam - provider: torch learning_rate: 1.e-3 training_dataset: diff --git a/tests/trainer/test_kim_trainer.py b/tests/trainer/test_kim_trainer.py new file mode 100644 index 00000000..d3d80085 --- /dev/null +++ b/tests/trainer/test_kim_trainer.py @@ -0,0 +1,120 @@ +import tarfile +from pathlib import Path + +import numpy as np +import pytest +import yaml + +from kliff.models import KIMModel +from kliff.trainer import KIMTrainer + + +def test_trainer(): + """ + Basic tests for proper initialization of the Trainer module + """ + manifest_file = filename = ( + Path(__file__) + .parents[1] + .joinpath("test_data/trainer_data/example_config_ase_kim.yaml") + ) + model = KIMModel("SW_StillingerWeber_1985_Si__MO_405512056662_006") + + manifest = yaml.safe_load(open(manifest_file, "r")) + + trainer = KIMTrainer(manifest) + + # check basic initialization + assert trainer.model.model_name == model.model_name + + default_model_params = model.model_params + trainer_params = trainer.model.model_params + + assert np.allclose(trainer_params["A"], default_model_params["A"]) + assert np.allclose(trainer_params["B"], default_model_params["B"]) + assert np.allclose(trainer_params["p"], default_model_params["p"]) + assert np.allclose(trainer_params["q"], default_model_params["q"]) + assert np.allclose(trainer_params["gamma"], default_model_params["gamma"]) + assert np.allclose(trainer_params["cutoff"], default_model_params["cutoff"]) + assert np.allclose(trainer_params["lambda"], default_model_params["lambda"]) + assert np.allclose(trainer_params["costheta0"], default_model_params["costheta0"]) + assert np.allclose(trainer.model.model_params["sigma"], np.log(2.0)) + + assert trainer.current["loss"] == None + assert trainer.current["epoch"] == 0 + assert trainer.current["step"] == 0 + assert trainer.current["device"] == "cpu" + assert trainer.current["warned_once"] == False + assert trainer.current["dataset_hash"] == None + assert trainer.current["data_dir"] == "test_run/datasets" + assert trainer.current["appending_to_previous_run"] == False + assert trainer.current["verbose"] == False + + # check dataset manifest + expected_dataset_manifest = { + "type": "ase", + "path": "../test_data/configs/Si_4.xyz", + "save": 
False, + "keys": {"energy": "Energy", "forces": "force"}, + "dynamic_loading": False, + "colabfit_dataset": { + "dataset_name": None, + "database_name": None, + "database_url": None, + }, + } + assert trainer.dataset_manifest == expected_dataset_manifest + + # check parameter settings + expected_parameter_manifest = [ + "A", + "B", + { + "sigma": { + "transform_name": "LogParameterTransform", + "value": 2.0, + "bounds": [[1.0, 10.0]], + } + }, + ] + + assert trainer.transform_manifest["parameter"] == expected_parameter_manifest + + expected_loss_manifest = { + "function": "MSE", + "weights": {"config": 1.0, "energy": 1.0, "forces": 1.0, "stress": None}, + "normalize_per_atom": True, + "loss_traj": False, + } + assert trainer.loss_manifest == expected_loss_manifest + + # dataset samples + assert trainer.dataset_sample_manifest["train_size"] == 3 + assert trainer.dataset_sample_manifest["val_size"] == 0 + assert trainer.dataset_sample_manifest["val_indices"] is None + assert isinstance(trainer.dataset_sample_manifest["train_indices"], np.ndarray) + + # check optimizer settings + expected_optimizer_manifest = { + "name": "L-BFGS-B", + "learning_rate": None, + "kwargs": {"tol": 1e-06}, + "epochs": 1000, + "stop_condition": None, + "num_workers": 2, + "batch_size": 1, + } + assert trainer.optimizer_manifest == expected_optimizer_manifest + + # dummy training + trainer.train() + # check if the trainer exited without any errors, check if .finished file is created + assert trainer.result.success + + # check if the kim model is saved, default folder is kim-model + trainer.save_kim_model() + + # assert + assert Path( + f"SW_StillingerWeber_trained_1985_Si__MO_405512056662_006.tar.gz" + ).exists() diff --git a/tests/trainer/test_lightning_trainer.py b/tests/trainer/test_lightning_trainer.py index 11be6ca4..0f206d37 100644 --- a/tests/trainer/test_lightning_trainer.py +++ b/tests/trainer/test_lightning_trainer.py @@ -58,8 +58,6 @@ def test_trainer(): } assert trainer.dataset_manifest == expected_dataset_manifest - assert trainer.model_manifest["type"] == "torch" - # check graph settings config_transform = { "name": "Graph", @@ -83,7 +81,6 @@ def test_trainer(): # check optimizer settings expected_optimizer_manifest = { - "provider": "torch", "name": "Adam", "learning_rate": 0.001, "kwargs": None,
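
A minimal end-to-end sketch of how the new trainer is intended to be driven, mirroring
tests/trainer/test_kim_trainer.py (the manifest path below is an assumption, relative to
the repository root; all calls appear in this patch):

    import yaml

    from kliff.trainer import KIMTrainer

    manifest = yaml.safe_load(
        open("tests/test_data/trainer_data/example_config_ase_kim.yaml")
    )
    trainer = KIMTrainer(manifest)  # sets up model, dataset, and optimizer from the manifest
    trainer.train()                 # scipy.optimize.minimize with the configured method
    trainer.save_kim_model()        # writes the KIM model and, per the export block, a tarball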