Lightning trainer (#1)
ipcamit committed Jun 23, 2024
1 parent a14718b commit 1a5812b
Showing 9 changed files with 114 additions and 72 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/testing.yml
@@ -55,6 +55,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install .[test]
# install torch dependencies
python -m pip install .[torch]
# TODO, here, we install ptemcee from Yonatan's fork. See setup.py for details.
python -m pip uninstall --yes ptemcee
python -m pip install git+https://github.com/yonatank93/ptemcee.git@enhance_v1.0.0
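For reference, the `.[torch]` extra installed above is an optional dependency group declared by the package itself. A minimal sketch of how such an extra might be declared in setup.py; the group contents shown here are illustrative, not the project's actual setup.py:

# Illustrative only: how optional extras like [test] and [torch] are typically declared.
from setuptools import setup

setup(
    name="kliff",
    extras_require={
        "test": ["pytest"],                       # installed by `pip install .[test]`
        "torch": ["torch", "pytorch_lightning"],  # installed by `pip install .[torch]`
    },
)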
4 changes: 3 additions & 1 deletion .gitignore
@@ -25,14 +25,16 @@ examples/Si_training_set
libs/geodesicLMv1.1
tests/echo*
tests/fingerprints/
kliff.log
kim.log

tmp_*
*_kliff_trained/
tests/uq/*.pkl
tests/uq/*.json
tests/uq/kliff_saved_model

tests/trainer/test_run
tests/trainer/kliff.log

# dataset
Si_training_set_4_configs
27 changes: 20 additions & 7 deletions kliff/dataset/dataset.py
@@ -5,7 +5,7 @@
import os
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import dill
import numpy as np
@@ -1180,13 +1180,22 @@ def check_properties_consistency(self, properties: List[str] = None):
logger.warning("No properties provided to check for consistency.")
return

property_list = list(copy.deepcopy(properties)) # make it mutable, if not
for config in self.configs:
for prop in property_list:
# property_list = list(copy.deepcopy(properties)) # make it mutable, if not
# for config in self.configs:
# for prop in property_list:
# try:
# getattr(config, prop)
# except ConfigurationError:
# property_list.remove(prop)
property_list = []
for prop in properties:
for config in self.configs:
try:
getattr(config, prop)
except ConfigurationError:
property_list.remove(prop)
break
else:
property_list.append(prop)

self.add_metadata({"consistent_properties": tuple(property_list)})
logger.info(
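The rewritten consistency check above relies on Python's for/else: the else branch runs only when the inner loop over configurations finishes without hitting break, so a property is kept only if every configuration provides it. A minimal, runnable sketch of that pattern, with plain dicts standing in for the real Configuration objects and the ConfigurationError check:

# Sketch of the for/else consistency check; dicts stand in for Configuration objects.
configs = [{"energy": -1.0, "forces": [0.0, 0.0, 0.0]}, {"energy": -2.0}]
candidate_properties = ["energy", "forces"]

consistent_properties = []
for prop in candidate_properties:
    for config in configs:
        if prop not in config:   # stands in for the getattr/ConfigurationError check
            break                # one missing configuration disqualifies the property
    else:
        # runs only if no break occurred, i.e. every configuration has the property
        consistent_properties.append(prop)

print(consistent_properties)  # ['energy']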
@@ -1195,7 +1204,8 @@ def check_properties_consistency(self, properties: List[str] = None):

@staticmethod
def get_manifest_checksum(
dataset_manifest: dict, transform_manifest: Optional[dict] = None
dataset_manifest: dict[str, Any],
transform_manifest: Optional[dict[str, Any]] = None,
) -> str:
"""
Get the checksum of the dataset manifest.
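The method body is collapsed in this view. A plausible minimal sketch, assuming the checksum is simply a hash of the serialized manifests; this is illustrative, not necessarily KLIFF's actual implementation:

# Illustrative sketch: hash the serialized manifests so identical settings map to the same key.
import hashlib
from typing import Any, Optional

def get_manifest_checksum(
    dataset_manifest: dict[str, Any],
    transform_manifest: Optional[dict[str, Any]] = None,
) -> str:
    payload = str(dataset_manifest) + str(transform_manifest)
    return hashlib.md5(payload.encode()).hexdigest()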
@@ -1273,7 +1283,10 @@ def get_dataset_from_manifest(dataset_manifest: dict) -> "Dataset":
and dataset_type != "path"
and dataset_type != "colabfit"
):
raise DatasetError(f"Dataset type {dataset_type} not supported.")
raise DatasetError(
f"Dataset type {dataset_type} not supported."
"Supported types are: ase, path, colabfit"
)
weights = dataset_manifest.get("weights", None)
if weights is not None:
if isinstance(weights, str):
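For context, a hypothetical dataset manifest that would pass this check might look like the following; apart from "weights", the key names are illustrative and not taken from this diff:

# Hypothetical manifest; "type" and "path" are illustrative key names.
dataset_manifest = {
    "type": "ase",                  # must be one of: ase, path, colabfit
    "path": "Si_training_set.xyz",  # illustrative dataset location
    "weights": None,                # read by get_dataset_from_manifest above
}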
39 changes: 22 additions & 17 deletions kliff/trainer/base_trainer.py
@@ -11,7 +11,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Union

import dill # TODO: include dill in requirements.txt
import numpy as np
import yaml
from loguru import logger
@@ -470,46 +469,46 @@ def setup_dataset_transforms(self):
property_name = property_to_transform.get("name", None)
if not property_name:
continue # it is probably an empty propery
transform_module_name = property_to_transform[property_name].get(
transform_class_name = property_to_transform[property_name].get(
"name", None
)
if not transform_module_name:
if not transform_class_name:
raise TrainerError(
"Property transform module name not provided."
)
property_transform_module = importlib.import_module(
f"kliff.transforms.property_transforms"
)
property_module = getattr(
property_transform_module, transform_module_name
TransformClass = getattr(
property_transform_module, transform_class_name
)
property_module = property_module(
TransformClass = TransformClass(
proprty_key=property_name,
**property_to_transform[property_name].get("kwargs", {}),
)
self.dataset = property_module(self.dataset)
self.property_transforms.append(property_module)
self.dataset = TransformClass(self.dataset)
self.property_transforms.append(TransformClass)

if configuration_transform:
configuration_module_name: Union[str, None] = (
configuration_class_name: Union[str, None] = (
configuration_transform.get("name", None)
)
if not configuration_module_name:
if not configuration_class_name:
logger.warning(
"Configuration transform module name not provided."
"Skipping configuration transform."
)
else:
configuration_module_name = (
configuration_class_name = (
"KIMDriverGraph"
if configuration_module_name.lower() == "graph"
else configuration_module_name
if configuration_class_name.lower() == "graph"
else configuration_class_name
)
configuration_transform_module = importlib.import_module(
f"kliff.transforms.configuration_transforms"
)
configuration_module = getattr(
configuration_transform_module, configuration_module_name
ConfigurationClass = getattr(
configuration_transform_module, configuration_class_name
)
kwargs: Union[dict, None] = configuration_transform.get(
"kwargs", None
@@ -518,11 +517,11 @@ def setup_dataset_transforms(self):
raise TrainerError(
"Configuration transform module options not provided."
)
configuration_module = configuration_module(
ConfigurationClass = ConfigurationClass(
**kwargs, copy_to_config=False
)

self.configuration_transform = configuration_module
self.configuration_transform = ConfigurationClass

def setup_model(self):
"""
@@ -766,3 +765,9 @@ class TrainerError(Exception):

def __init__(self, message):
super().__init__(message)


# TODO:
# 1. Test dataset
# 2. Stress
# 3. Get top k models
27 changes: 14 additions & 13 deletions kliff/trainer/lightning_trainer.py
@@ -32,13 +32,13 @@

from pytorch_lightning.callbacks import EarlyStopping

from .torch_trainer_utils.lightning_checkpoints import (
from .torch_trainer_utils.lightning_utils import (
LossTrajectoryCallback,
SaveModelCallback,
)


class LightningTrainerWrapper(pl.LightningModule):
class LightningTrainer(pl.LightningModule):
"""
Wrapper class for Pytorch Lightning Module. This class is used to wrap the model and
the training loop in a Pytorch Lightning Module. It returns the energy, and forces
@@ -48,8 +48,7 @@ class LightningTrainerWrapper(pl.LightningModule):
Args:
model: Pytorch model to be trained
input_args: List of input arguments to the model. They are passed as dictionary
to the model, therefore order, and the name of the arguments should be the
same as in the model definition. Example: ["x", "coords", "edge_index0", "edge_index1" ...,"batch"]
to the model. Example: ["x", "coords", "edge_index0", "edge_index1" ...,"batch"]
ckpt_dir: Directory to save the checkpoints
device: Device to run the model on. Default is "cpu"
ema: Whether to use Exponential Moving Average. Default is True
@@ -112,14 +111,12 @@ def forward(self, batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
model_inputs = {k: batch[k] for k in self.input_args}
predicted_energy = self.model(**model_inputs)
(predicted_forces,) = torch.autograd.grad(
[predicted_energy],
predicted_energy.sum(),
batch["coords"],
create_graph=True, # TODO: grad against arbitrary param name
retain_graph=True,
grad_outputs=torch.ones_like(predicted_energy),
)
predicted_forces = scatter_add(predicted_forces, batch["images"], dim=0)
return predicted_energy, -predicted_forces
predicted_forces = -scatter_add(predicted_forces, batch["images"], dim=0)
return predicted_energy, predicted_forces

def training_step(self, batch, batch_idx):
"""
@@ -251,7 +248,7 @@ def validation_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]:
class GNNLightningTrainer(Trainer):
"""
Trainer class for GNN models. This class is used to train GNN models using Pytorch
Lightning. It uses the `LightningTrainerWrapper` to wrap the model and the training
Lightning. It uses the `LightningTrainer` to wrap the model and the training
loop in a Pytorch Lightning Module. It also handles the dataloaders, loggers, and
callbacks.
"""
Expand All @@ -264,7 +261,7 @@ def __init__(self, manifest, model):
manifest: Dictionary containing the manifest for the trainer.
model: Pytorch model to be trained.
"""
self.pl_model: LightningTrainerWrapper = None
self.pl_model: LightningTrainer = None
self.data_module = None

super().__init__(manifest, model)
@@ -280,7 +277,7 @@ def __init__(self, manifest, model):

def setup_model(self):
"""
Set up the model for training. This function initializes the `LightningTrainerWrapper`
Set up the model for training. This function initializes the `LightningTrainer`
with the model, and the training parameters.
"""
# if dict has key ema, then set ema to True, decay to the dict value, else set ema false
Expand All @@ -292,7 +289,7 @@ def setup_model(self):

scheduler = self.optimizer_manifest.get("lr_scheduler", {})

self.pl_model = LightningTrainerWrapper(
self.pl_model = LightningTrainer(
model=self.model,
input_args=self.model_manifest["input_args"],
ckpt_dir=self.current["run_dir"],
@@ -512,3 +509,7 @@ def save_kim_model(self, path: str = "kim-model"):
def setup_optimizer(self):
# Not needed as Pytorch Lightning handles the optimizer
pass

def seed_all(self):
super().seed_all()
pl.seed_everything(self.current["seed"])
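For reference, Lightning's seed helper covers the Python, NumPy, and torch (CPU and CUDA) RNGs in one call; the override above layers it on top of whatever the base class's seed_all already seeds:

import pytorch_lightning as pl

pl.seed_everything(42)  # seeds random, numpy, and torch RNGs for reproducible runs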