Skip to content

Commit

Permalink
Merge pull request #183 from ipcamit/kliff-master-v1-kim
Browse files Browse the repository at this point in the history
KIM Trainer and tests
  • Loading branch information
mjwen authored Jun 28, 2024
2 parents 016d6f7 + 544bdae commit 4db0f56
Show file tree
Hide file tree
Showing 12 changed files with 512 additions and 71 deletions.
10 changes: 5 additions & 5 deletions kliff/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def add_weights(configurations: List[Configuration], path: Union[Path, str]):
"""
if path is None:
logger.info("No weights provided.")
logger.info("No explicit weights datafile provided.")
return

weights_data = np.genfromtxt(path, names=True)
Expand Down Expand Up @@ -1293,10 +1293,10 @@ def get_dataset_from_manifest(dataset_manifest: dict) -> "Dataset":
weights = Path(weights)
elif isinstance(weights, dict):
weights = Weight(
config_weight=weights.get("config", 0.0),
energy_weight=weights.get("energy", 0.0),
forces_weight=weights.get("forces", 0.0),
stress_weight=weights.get("stress", 0.0),
config_weight=weights.get("config", None),
energy_weight=weights.get("energy", None),
forces_weight=weights.get("forces", None),
stress_weight=weights.get("stress", None),
)
else:
raise DatasetError("Weights must be a path or a dictionary.")
Expand Down
16 changes: 8 additions & 8 deletions kliff/dataset/weight.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import numpy as np
from loguru import logger
Expand Down Expand Up @@ -27,10 +27,10 @@ class Weight:

def __init__(
self,
config_weight: float = 1.0,
energy_weight: float = 1.0,
forces_weight: float = 1.0,
stress_weight: float = 1.0,
config_weight: Union[float, None] = 1.0,
energy_weight: Union[float, None] = 1.0,
forces_weight: Union[float, None] = 1.0,
stress_weight: Union[float, None] = 1.0,
):
self._config_weight = config_weight
self._energy_weight = energy_weight
Expand Down Expand Up @@ -91,11 +91,11 @@ def _check_compute_flag(self, config):

# If the weight are really small, but not zero, then warn the user. Zero weight
# usually means that the property is used.
if config._energy is not None and np.all(ew < 1e-12):
if config._energy is not None and ew is not None and np.all(ew < 1e-12):
logger.warning(msg.format("energy", ew))
if config._forces is not None and np.all(fw < 1e-12):
if config._forces is not None and fw is not None and np.all(fw < 1e-12):
logger.warning(msg.format("forces", fw))
if config._stress is not None and np.all(sw < 1e-12):
if config._stress is not None and sw is not None and np.all(sw < 1e-12):
logger.warning(msg.format("stress", sw))


Expand Down
62 changes: 28 additions & 34 deletions kliff/models/kim.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,10 @@ def write_kim_model(self, path: Path = None):
Write out a KIM model that can be used directly with the kim-api.
This function typically write two files to `path`: (1) CMakeLists.txt, and (2)
a parameter file like A.model_params. `path` will be created if it does not exist.
a parameter file like A.params. `path` will be created if it does not exist.
Args:
path: Path to the a directory to store the model. If `None`, it is set to
path: Path to a directory to store the model. If `None`, it is set to
`./MODEL_NAME_kliff_trained`, where `MODEL_NAME` is the `model_name` that
provided at the initialization of this class.
Expand Down Expand Up @@ -696,7 +696,11 @@ def __call__(
return kim_ca_instance.results

@staticmethod
def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
def get_model_from_manifest(
model_manifest: dict,
param_manifest: dict = None,
is_model_tarfile: bool = False,
):
"""
Get the model from a configuration. If it is a valid KIM model, it will return
the KIMModel object. If it is a TorchML model, it will return the torch
Expand All @@ -710,7 +714,6 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
Example `model_manifest`:
```yaml
model:
model_type: kim # kim or torch
model_path: ./model.tar.gz # path to the model tarball
model_name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing
model_collection: "user"
Expand All @@ -735,12 +738,12 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
Args:
model_manifest: configuration object
param_manifest: parameter transformation configuration
is_model_tarfile: whether the model is a tarball
Returns:
Model object
"""
model_name: Union[None, str] = model_manifest.get("name", None)
model_type: Union[None, str] = model_manifest.get("type", None)
model_path: Union[None, str, Path] = model_manifest.get("path", None)
model_driver = KIMModel.get_model_driver_name(model_name)
model_collection = model_manifest.get("collection")
Expand All @@ -754,35 +757,9 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
f"Model driver {model_driver} not supported for KIMModel training."
)

# is model a tarball?
if model_path is not None:
model_path = Path(model_path)
if model_path.suffix == ".tar":
model_type = "tar"

# ensure model is installed
if model_type.lower() == "kim":
# is it a tar file?
is_model_installed = is_kim_model_installed(model_name)
if is_model_installed:
logger.info(f"Model {model_name} is already installed, continuing ...")
else:
logger.info(
f"Model {model_name} not installed on system, attempting to installing ..."
)
was_install_success = install_kim_model(model_name, model_collection)
if not was_install_success:
logger.error(
f"Model {model_name} not found in the KIM API collections. Please check the model name and try again."
)
raise KIMModelError(f"Model {model_name} not found.")
else:
logger.info(
f"Model {model_name} installed in {model_collection} collection."
)

elif model_type.lower() == "tar":
archive_content = tarfile.open(model_path + "/" + model_name)
if is_model_tarfile:
archive_content = tarfile.open(model_path)
model = archive_content.getnames()[0]
archive_content.extractall(model_path)
subprocess.run(
Expand All @@ -798,8 +775,25 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
logger.info(
f"Tarball Model {model} installed in {model_collection} collection."
)

is_model_installed = is_kim_model_installed(model_name)

if is_model_installed:
logger.info(f"Model {model_name} is already installed, continuing ...")
else:
raise KIMModelError(f"Model type {model_type} not supported.")
logger.info(
f"Model {model_name} not installed on system, attempting to installing ..."
)
was_install_success = install_kim_model(model_name, model_collection)
if not was_install_success:
logger.error(
f"Model {model_name} not found in the KIM API collections. Please check the model name and try again."
)
raise KIMModelError(f"Model {model_name} not found.")
else:
logger.info(
f"Model {model_name} installed in {model_collection} collection."
)

model = KIMModel(model_name)

Expand Down
2 changes: 1 addition & 1 deletion kliff/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_trainer import Trainer
from .kim_trainer import KIMTrainer
from .lightning_trainer import GNNLightningTrainer

# from .kim_trainer import KIMTrainer
# from .torch_trainer import DNNTrainer
17 changes: 17 additions & 0 deletions kliff/trainer/_kim_loss_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np


def MSE_loss(
predictions: np.ndarray,
targets: np.ndarray,
) -> np.ndarray:
r"""
Compute the mean squared error (MSE) of the residuals.
Args:
Returns:
The MSE of the residuals.
"""
residuals = predictions - targets
return np.mean(residuals**2)
56 changes: 40 additions & 16 deletions kliff/trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from loguru import logger

from kliff.dataset import Dataset
from kliff.dataset.weight import Weight

if TYPE_CHECKING:
from kliff.transforms.configuration_transforms import ConfigurationTransform
Expand Down Expand Up @@ -104,9 +105,9 @@ def __init__(self, training_manifest: dict, model=None):

# model variables
self.model_manifest: dict = {
"type": "kim",
"name": None,
"path": None,
"collection": "user",
}
self.model: Callable = model

Expand Down Expand Up @@ -144,7 +145,6 @@ def __init__(self, training_manifest: dict, model=None):
}

self.optimizer_manifest: dict = {
"provider": "scipy",
"name": None,
"learning_rate": None,
"kwargs": None,
Expand All @@ -170,6 +170,7 @@ def __init__(self, training_manifest: dict, model=None):
self.export_manifest: dict = {
"model_name": None,
"model_path": None,
"generate_tarball": False,
}

# state variables
Expand Down Expand Up @@ -426,12 +427,29 @@ def setup_dataset(self):
TODO: ColabFit integration for extreme scale datasets.
"""
dataset_module_manifest = deepcopy(self.dataset_manifest)
dataset_module_manifest["weights"] = self.loss_manifest["weights"]
# dataset_module_manifest["weights"] = self.loss_manifest["weights"]

weights = self.loss_manifest["weights"]

if weights is not None:
if isinstance(weights, str):
weights = Path(weights)
elif isinstance(weights, dict):
weights = Weight(
config_weight=weights.get("config", None),
energy_weight=weights.get("energy", None),
forces_weight=weights.get("forces", None),
stress_weight=weights.get("stress", None),
)
else:
raise TrainerError("Weights must be a path or a dictionary.")

dataset_list = _parallel_read(
self.dataset_manifest["path"],
num_chunks=self.optimizer_manifest["num_workers"],
energy_key=self.dataset_manifest.get("keys", {}).get("energy", "energy"),
forces_key=self.dataset_manifest.get("keys", {}).get("forces", "forces"),
weights=weights,
)
self.dataset = deepcopy(dataset_list[0])
for ds in dataset_list[1:]:
Expand Down Expand Up @@ -637,17 +655,18 @@ def setup_dataset_split(self):
fmt="%d",
)

if isinstance(val_indices, str):
self.dataset_sample_manifest["indices_files"]["val"] = val_indices
else:
self.dataset_sample_manifest["indices_files"][
"val"
] = f"{self.current['run_dir']}/val_indices.txt"
np.savetxt(
self.dataset_sample_manifest["indices_files"]["val"],
val_indices,
fmt="%d",
)
if val_size > 0:
if isinstance(val_indices, str):
self.dataset_sample_manifest["indices_files"]["val"] = val_indices
else:
self.dataset_sample_manifest["indices_files"][
"val"
] = f"{self.current['run_dir']}/val_indices.txt"
np.savetxt(
self.dataset_sample_manifest["indices_files"]["val"],
val_indices,
fmt="%d",
)

def loss(self, *args, **kwargs):
raise TrainerError("loss not implemented.")
Expand Down Expand Up @@ -708,6 +727,7 @@ def _parallel_read(
num_chunks=None,
energy_key="Energy",
forces_key="forces",
weights=None,
) -> List[Dataset]:
"""
Read and transform frames in parallel. Returns n_chunks of datasets.
Expand Down Expand Up @@ -742,18 +762,22 @@ def _parallel_read(
with multiprocessing.Pool(processes=num_chunks) as pool:
ds = pool.starmap(
_read_frames,
[(file_path, start, end, energy_key, forces_key) for start, end in chunks],
[
(file_path, start, end, energy_key, forces_key, weights)
for start, end in chunks
],
)

return ds


def _read_frames(file_path, start, end, energy_key, forces_key):
def _read_frames(file_path, start, end, energy_key, forces_key, weights):
ds = Dataset.from_ase(
path=file_path,
energy_key=energy_key,
forces_key=forces_key,
slices=slice(start, end),
weight=weights,
)
return ds

Expand Down
Loading

0 comments on commit 4db0f56

Please sign in to comment.