Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 15, 2023
1 parent 3ffe4f2 commit cc0c30a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
18 changes: 15 additions & 3 deletions kgcnn/data/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import numpy as np
import tensorflow as tf
import pandas as pd
import os
from sklearn.model_selection import KFold
# import typing as t
from typing import Union, List, Callable, Dict, Optional
# from collections.abc import MutableSequence
Expand Down Expand Up @@ -705,10 +705,10 @@ def get_train_test_indices(self,
dataset belongs to that part of the split (train / test)
- The property is a list containing integer split indices, where each split index present within
that list implies that the corresponding dataset element is part of that particular split.
In this case the ``split_index`` parameter may also be a list of split indices that specifies
In this case the `split_index` parameter may also be a list of split indices that specify
for which of these split indices the train test index split is to be returned by this method.
The return value of this method is a list with the same length as the ``split_index`` parameter,
The return value of this method is a list with the same length as the `split_index` parameter,
which by default will be 1.
Args:
Expand Down Expand Up @@ -808,5 +808,17 @@ def get_multi_target_indices(self, graph_labels: str = "graph_labels", multi_tar
self.info("Labels '%s' in '%s' have shape '%s'." % (label_names, label_units, labels.shape))
return labels, label_names, label_units

def set_train_test_indices_k_fold(self, n_splits: int = 5, shuffle: bool = False, random_state: int = None,
train: str = "train", test: str = "test"):
kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
for x in self:
x.set(train, [])
x.set(test, [])
for fold, (train_index, test_index) in enumerate(kf.split(np.expand_dims(np.arange(len(self)), axis=0))):
for i in train_index:
self[i].set(train, list(self[i].get(train)) + [fold])
for i in train_index:
self[i].set(test, list(self[i].get(test)) + [fold])


MemoryGeometricGraphDataset = MemoryGraphDataset
7 changes: 1 addition & 6 deletions training_core/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datetime import timedelta
from kgcnn.training_core.history import save_history_score, load_history_list
from kgcnn.metrics_core.metrics import ScaledMeanAbsoluteError, ScaledRootMeanSquaredError
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler as StandardLabelScaler
from kgcnn.utils_core.plots import plot_train_test_loss, plot_predict_true
from kgcnn.model.serial import deserialize as deserialize_model
Expand Down Expand Up @@ -72,10 +71,6 @@
data_unit=hyper["data"]["data_unit"] if "data_unit" in hyper["data"] else None
)

# Cross-validation via random KFold split form `sklearn.model_selection`. Other validation schemes could include
# stratified k-fold cross-validation for `MoleculeNetDataset` but is not implemented yet.
kf = KFold(**hyper["training"]["cross_validation"]["config"])

# Iterate over the cross-validation splits.
# Indices for train-test splits are stored in 'test_indices_list'.
execute_folds = args["fold"] if "execute_folds" not in hyper["training"] else hyper["training"]["execute_folds"]
Expand Down Expand Up @@ -143,7 +138,7 @@
save_pickle_file(hist.history, os.path.join(filepath, f"history{postfix_file}_fold_{current_split}.pickle"))

# Plot training- and test-loss vs epochs for all splits.
history_list = load_history_list(os.path.join(filepath, f"history{postfix_file}_fold_(i).pickle"), current_split)
history_list = load_history_list(os.path.join(filepath, f"history{postfix_file}_fold_(i).pickle"), current_split+1)
plot_train_test_loss(history_list, loss_name=None, val_loss_name=None,
model_name=hyper.model_name, data_unit=label_units, dataset_name=hyper.dataset_class,
filepath=filepath, file_name=f"loss{postfix_file}.png")
Expand Down

0 comments on commit cc0c30a

Please sign in to comment.