diff --git a/kgcnn/literature/GNNExplain.py b/kgcnn/literature/GNNExplain.py index e1b2eeb5..290d97f7 100644 --- a/kgcnn/literature/GNNExplain.py +++ b/kgcnn/literature/GNNExplain.py @@ -1,11 +1,201 @@ +""" +"Ying et al. - GNNExplainer: Generating Explanations for Graph Neural Networks" + +**Changelog** + +??.??.2022 - Initial implementation + +30.01.2023 - Added the class "GnnExplainer" which supports RaggedTensors and can thus generate multiple +explanations at once, greatly improving time efficiency for explaining large batches of predictions. However +the new class does not implement visualization of the explanations. This will have to be realized on a +higher abstraction level. +""" +import time +import typing as t + +import numpy as np import tensorflow as tf ks = tf.keras +from kgcnn.xai.base import ImportanceExplanationMethod + # Keep track of model version from commit date in literature. # To be updated if model is changed in a significant way. __model_version__ = "2022.05.31" +# == REDUCED, RAGGED TENSOR IMPLEMENTATION == + +class GnnExplainer(ImportanceExplanationMethod): + """ + Implementation of "ImportanceExplanationMethod", which means that calling an instance of this class + given a model, a ragged input tensor and output predictions, it should return the corresponding + node and edge importance tensors, which provide an explanation by assigning each node and edge of the + input graphs with a 0-1 importance value. + + By the nature of the base idea behind GNNExplainer, the number of explanations produced has to be equal + to the number of prediction targets that are generated by the model. Each target will receive its own + explanation. + """ + def __init__(self, + channels: int, + epochs: int = 100, + learning_rate: float = 0.01, + node_sparsity_factor: float = 0.1, + edge_sparsity_factor: float = 0.1, + log_step: int = 10, + verbose: bool = True): + super(GnnExplainer, self).__init__(channels=channels) + self.epochs = epochs + self.learning_rate = learning_rate + self.log_step = log_step + self.verbose = verbose + self.node_sparsity_factor = node_sparsity_factor + self.edge_sparsity_factor = edge_sparsity_factor + + def __call__(self, + model: ks.models.Model, + x: t.Tuple[tf.RaggedTensor, tf.RaggedTensor, tf.RaggedTensor], + y: np.ndarray): + """ + Given a model, the input tensor and the output array, this method will return a tuple of two + ragged tensors, which represent the node importances and the edge importances. + + Beware, that this method executes an entire training process and may take some time. + + Reference of tensor shapes. [Brackets] indicate ragged dimension + - V: Number of nodes in graph + - E: Number of edges in graph + - K: Number of explanation channels given in constructor. This has to be equal to the number of + prediction targets specified in the constructor. + - N: Number of node attributes + - M: Number of edge attributes + - B: batch size + + Args: + x: A tuple (node_input, edge_input, edge_indices) of 3 RaggedTensors + - node_input: Shape ([B], [V], N) + - edge_input: Shape ([B], [E], M) + - edge_indices: Shape ([B], [E], 2) + y: A numpy array of shape (B, K) + model: Any compatible keras model, which means any model which accepts the previously described + input tensors and returns output similar to the previously described output tensor. + + Returns: + A tuple (node_importances, edge_importances) of RaggedTensors. + - node_importances: Shape ([B], [V], K) + - edge_importances: Shape ([B], [E], K) + """ + # Generally the idea of the implementation is that we use the node_input and edge_input tensors as + # templates to generate the mask variable tensors, which match the graph dimensions but differ in + # the final dimension, which instead of the node / edge features we will use to represent the + # number of importance channels (== number of prediction targets). + + node_input, edge_input, edge_indices = x + + # Here we reduce away the last dimension of node and edge input to get just the ragged graph sizes + # But we run into a problem here with multiple channels: We cant actually use the last dimension to + # represent the number of different explanation channels. Instead, we do a workaround here where for + # each channel we extend the batch dimension. Aka we assume that all the different channels are just + # additional graphs to be treated like the others. The reason why we have to do it like that is + # because later on we need to multiply the masks with the inputs! + node_mask_single = tf.reduce_mean(tf.ones_like(node_input), axis=-1, keepdims=True) + node_mask_ragged = tf.concat([node_mask_single for _ in range(self.channels)], axis=0) + node_mask_variables = tf.Variable(node_mask_ragged.flat_values, trainable=True, dtype=tf.float64) + + edge_mask_single = tf.reduce_mean(tf.ones_like(edge_input), axis=-1, keepdims=True) + edge_mask_ragged = tf.concat([edge_mask_single for _ in range(self.channels)], axis=0) + edge_mask_variables = tf.Variable(edge_mask_ragged.flat_values, trainable=True, dtype=tf.float64) + + optimizer = ks.optimizers.Nadam(learning_rate=self.learning_rate) + + # This is a logical extension of what was previously described. Since we treat the different + # explanation channels as just a batch extension, we have to modify the input values and the output + # values accordingly so that they have the same batch size so to say. Naturally we simply have to + # duplicate the values. + x_extended = ( + tf.concat([node_input for _ in range(self.channels)], axis=0), + tf.concat([edge_input for _ in range(self.channels)], axis=0), + tf.concat([edge_indices for _ in range(self.channels)], axis=0), + ) + y_extended = [] + for c in range(self.channels): + y_mod = np.zeros_like(y) + y_mod[:, c] = y[:, c] + y_extended.append(y_mod) + + y_extended = np.concatenate(y_extended) + + start_time = time.time() + for epoch in range(self.epochs): + + with tf.GradientTape() as tape: + node_mask = tf.RaggedTensor.from_nested_row_splits( + node_mask_variables, + nested_row_splits=node_mask_ragged.nested_row_splits + ) + + edge_mask = tf.RaggedTensor.from_nested_row_splits( + edge_mask_variables, + nested_row_splits=edge_mask_ragged.nested_row_splits + ) + + out = model([ + x_extended[0] * node_mask, + x_extended[1] * edge_mask, + x_extended[2] + ]) + + # The loss can basically be summerized as: We try to find the smallest subset of nodes and + # edges in the input, which will cause the network to get as close as possible to it's + # original prediction! + loss = tf.cast(tf.reduce_mean(tf.square(y_extended - out)), dtype=tf.float64) + # Important detail: The reduce_sum here reduces over all the nodes / edges and is necessary! + loss += self.node_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(node_mask), axis=1)) + loss += self.edge_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(edge_mask), axis=1)) + + trainable_vars = [node_mask_variables, edge_mask_variables] + gradients = tape.gradient(loss, trainable_vars) + optimizer.apply_gradients(zip(gradients, trainable_vars)) + + if self.verbose and epoch % self.log_step == 0: + print(f' * epoch ({epoch}/{self.epochs}) ' + f' - loss: {loss}' + f' - elapsed time: {time.time()-start_time:.2f} seconds') + + # For the training we had to treat the different explanation channels as a batch extension. As per + # the interface we need to return the importances however such that the different explanation + # channels are organized into the third dimension of the tensors. + + # Sadly this does not work in a more direct fashion. We get the number of elements of nodes and + # edges that belong to one explanation channel. Iterate in chunks of that size and turn each of + # those chunks into it's own explanation respectively. At the end we concatenate all of them in + # the 3rd dimension to produce the desired result. + num_elements_node = node_mask_single.flat_values.shape[0] + num_elements_edge = edge_mask_single.flat_values.shape[0] + node_importances_list = [] + edge_importances_list = [] + for c in range(self.channels): + node_importances_part = tf.RaggedTensor.from_nested_row_splits( + node_mask_variables[c*num_elements_node:(c+1)*num_elements_node, :], + node_mask_single.nested_row_splits + ) + node_importances_list.append(node_importances_part) + + edge_importances_part = tf.RaggedTensor.from_nested_row_splits( + edge_mask_variables[c*num_elements_edge:(c+1)*num_elements_edge, :], + edge_mask_single.nested_row_splits + ) + edge_importances_list.append(edge_importances_part) + + return ( + tf.concat(node_importances_list, axis=-1), + tf.concat(edge_importances_list, axis=-1) + ) + + +# == ORIGINAL IMPLEMENTATION == + class GNNInterface: """An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable. This class is just an interface, which is used by the `GNNExplainer` and should be implemented in a subclass. diff --git a/kgcnn/xai/base.py b/kgcnn/xai/base.py index c44d378e..a705b83b 100644 --- a/kgcnn/xai/base.py +++ b/kgcnn/xai/base.py @@ -1,6 +1,10 @@ import typing as t +import numpy as np import tensorflow as tf +import tensorflow.keras as ks + +from kgcnn.data.utils import ragged_tensor_from_nested_numpy class AbstractExplanationMixin: @@ -20,3 +24,55 @@ def explain_importances(self, **kwargs ) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]: raise NotImplementedError + + +class AbstractExplanationMethod: + + def __call__(self, model, x, y): + raise NotImplementedError + + +class ImportanceExplanationMethod(AbstractExplanationMethod): + + def __init__(self, + channels: int): + self.channels = channels + + def __call__(self, + model: ks.models.Model, + x: tf.Tensor, + y: tf.Tensor + ) -> t.Tuple[tf.Tensor, tf.Tensor]: + raise NotImplementedError + + +class MockImportanceExplanationMethod(ImportanceExplanationMethod): + """ + This is a mock implementation of "ImportanceExplanationMethod". It is purely for testing purposes. + Using this method will result in randomly generated importance values for nodes and edges. + """ + def __init__(self, channels): + super(MockImportanceExplanationMethod, self).__init__(channels=channels) + + def __call__(self, + model: ks.models.Model, + x: t.Tuple[tf.Tensor], + y: t.Tuple[tf.Tensor], + ) -> t.Tuple[tf.Tensor, tf.Tensor]: + node_input, edge_input, _ = x + + # Im sure you could probably do this in tensorflow directly, but I am just going to go the numpy + # route here because that's just easier. + node_input = node_input.numpy() + edge_input = edge_input.numpy() + + node_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels)) + for v in node_input] + edge_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels)) + for v in edge_input] + + return ( + ragged_tensor_from_nested_numpy(node_importances), + ragged_tensor_from_nested_numpy(edge_importances) + ) + diff --git a/kgcnn/xai/testing.py b/kgcnn/xai/testing.py new file mode 100644 index 00000000..833e2777 --- /dev/null +++ b/kgcnn/xai/testing.py @@ -0,0 +1,130 @@ +import random +import typing as t + +import numpy as np +import tensorflow as tf +import tensorflow.keras as ks + +from kgcnn.layers.conv.gat_conv import AttentionHeadGATV2 +from kgcnn.layers.modules import DenseEmbedding +from kgcnn.layers.pooling import PoolingGlobalEdges +from kgcnn.data.utils import ragged_tensor_from_nested_numpy + + +# This is a very simple mock implementation, because to test the explanation methods we need some sort +# of a model as basis and this model will act as such. +class Model(ks.models.Model): + + def __init__(self, + num_targets: int = 1): + super(Model, self).__init__() + self.conv_layers = [ + AttentionHeadGATV2(units=64, use_edge_features=True, use_bias=True), + ] + self.lay_pooling = PoolingGlobalEdges(pooling_method='sum') + self.lay_dense = DenseEmbedding(units=num_targets, activation='linear') + + def call(self, inputs, training=False): + node_input, edge_input, edge_index_input = inputs + x = node_input + for lay in self.conv_layers: + x = lay([x, edge_input, edge_index_input]) + + pooled = self.lay_pooling(x) + out = self.lay_dense(pooled) + return out + + +class MockContext: + + def __init__(self, + num_elements: int = 10, + num_targets: int = 1, + epochs: int = 10, + batch_size: int = 2): + self.num_elements = num_elements + self.num_targets = num_targets + self.epochs = epochs + self.batch_size = batch_size + + self.model = Model(num_targets=num_targets) + self.x = None + self.y = None + + def generate_graph(self, + num_nodes: int, + num_node_attributes: int = 3, + num_edge_attributes: int = 1): + remaining = list(range(num_nodes)) + random.shuffle(remaining) + inserted = [remaining.pop(0)] + node_attributes = [[random.random() for _ in range(num_node_attributes)] for _ in range(num_nodes)] + edge_indices = [] + edge_attributes = [] + while len(remaining) != 0: + i = remaining.pop(0) + j = random.choice(inserted) + inserted.append(i) + + edge_indices += [[i, j], [j, i]] + edge_attribute = [1 for _ in range(num_edge_attributes)] + edge_attributes += [edge_attribute, edge_attribute] + + return ( + np.array(node_attributes, dtype=float), + np.array(edge_attributes, dtype=float), + np.array(edge_indices, dtype=int) + ) + + def generate_data(self): + node_attributes_list = [] + edge_attributes_list = [] + edge_indices_list = [] + targets_list = [] + for i in range(self.num_elements): + num_nodes = random.randint(5, 20) + node_attributes, edge_attributes, edge_indices = self.generate_graph(num_nodes) + node_attributes_list.append(node_attributes) + edge_attributes_list.append(edge_attributes) + edge_indices_list.append(edge_indices) + + # The target value we will actually determine deterministically here so that our network + # actually has a chance to learn anything + target = np.sum(node_attributes) + targets = [target for _ in range(self.num_targets)] + targets_list.append(targets) + + self.x = ( + ragged_tensor_from_nested_numpy(node_attributes_list), + ragged_tensor_from_nested_numpy(edge_attributes_list), + ragged_tensor_from_nested_numpy(edge_indices_list) + ) + + self.y = ( + np.array(targets_list, dtype=float) + ) + + def __enter__(self): + # This method will generate random input and output data and thus populate the internal attributes + # self.x and self.y + self.generate_data() + + # Using these we will train our mock model for a few very brief epochs. + self.model.compile( + loss=ks.losses.mean_squared_error, + metrics=ks.metrics.mean_squared_error, + run_eagerly=False, + optimizer=ks.optimizers.Nadam(learning_rate=0.01), + ) + hist = self.model.fit( + self.x, self.y, + batch_size=self.batch_size, + epochs=self.epochs, + verbose=0, + ) + self.history = hist.history + + return self + + def __exit__(self, *args, **kwargs): + pass diff --git a/test/test_literature_gnnexplain.py b/test/test_literature_gnnexplain.py new file mode 100644 index 00000000..b891dc85 --- /dev/null +++ b/test/test_literature_gnnexplain.py @@ -0,0 +1,41 @@ +import unittest + +import tensorflow as tf + +from kgcnn.xai.testing import MockContext +from kgcnn.literature.GNNExplain import GnnExplainer + + +class TestGnnExplainer(unittest.TestCase): + + def test_basically_works(self): + num_targets = 1 + with MockContext(num_targets=num_targets) as mock: + gnn_explainer = GnnExplainer( + channels=num_targets, + verbose=True + ) + node_importances, edge_importances = gnn_explainer( + model=mock.model, + x=mock.x, + y=mock.y, + ) + assert isinstance(node_importances, tf.RaggedTensor) + assert isinstance(edge_importances, tf.RaggedTensor) + + def test_multiple_targets_works(self): + num_targets = 2 + with MockContext(num_targets=num_targets) as mock: + gnn_explainer = GnnExplainer( + channels=num_targets, + verbose=True + ) + node_importances, edge_importances = gnn_explainer( + model=mock.model, + x=mock.x, + y=mock.y, + ) + assert isinstance(node_importances, tf.RaggedTensor) + assert isinstance(edge_importances, tf.RaggedTensor) + + diff --git a/test/test_xai_base.py b/test/test_xai_base.py new file mode 100644 index 00000000..9c20de68 --- /dev/null +++ b/test/test_xai_base.py @@ -0,0 +1,39 @@ +import unittest + +import numpy as np +import tensorflow as tf +import tensorflow.keras as ks + +from kgcnn.xai.testing import MockContext +from kgcnn.xai.base import MockImportanceExplanationMethod + + +# == UNIT TESTS == + +class TestMockImportanceExplanationMethod(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.mock = MockContext() + cls.mock.__enter__() + + @classmethod + def tearDownClass(cls) -> None: + cls.mock.__exit__(None, None, None) + + def test_basically_works(self): + channels = 3 + xai_instance = MockImportanceExplanationMethod(channels=channels) + node_importances, edge_importances = xai_instance( + self.mock.model, + self.mock.x, + self.mock.y + ) + assert isinstance(node_importances, tf.RaggedTensor) + assert isinstance(edge_importances, tf.RaggedTensor) + + node_importances = node_importances.numpy() + edge_importances = edge_importances.numpy() + for v, w in zip(node_importances, edge_importances): + assert v.shape[-1] == channels + assert w.shape[-1] == channels diff --git a/test/test_xai_testing.py b/test/test_xai_testing.py new file mode 100644 index 00000000..73982d83 --- /dev/null +++ b/test/test_xai_testing.py @@ -0,0 +1,33 @@ +import unittest + +import numpy as np +import tensorflow as tf +import tensorflow.keras as ks + +from kgcnn.xai.testing import MockContext + + +class TestMockContext(unittest.TestCase): + + def test_basically_works(self): + num_elements = 10 + num_targets = 2 + with MockContext(num_elements=num_elements, num_targets=num_targets) as mock: + assert isinstance(mock, MockContext) + + assert isinstance(mock.model, ks.models.Model) + assert mock.model.built + + assert isinstance(mock.x, tuple) + assert len(mock.x) == 3 + for element in mock.x: + print(element.shape) + assert isinstance(element, tf.RaggedTensor) + + assert isinstance(mock.y, tuple) + assert len(mock.y) == 1 + + targets = mock.y[0] + assert isinstance(targets, np.ndarray) + assert len(targets.shape) == 2 + assert targets.shape[-1] == num_targets diff --git a/training/hyper/hyper_vgd_mock.py b/training/hyper/hyper_vgd_mock.py index 5d5ce478..617a991d 100644 --- a/training/hyper/hyper_vgd_mock.py +++ b/training/hyper/hyper_vgd_mock.py @@ -57,4 +57,72 @@ "kgcnn_version": "2.2.0" } }, + "GCN": { + "explanation": { + "channels": 2, + "gt_suffix": None, + }, + "xai_methods": { + "Mock": { + "class_name": "MockImportanceExplanationMethod", + "module_name": "kgcnn.xai.base", + "config": {} + }, + "GnnExplainer": { + "class_name": "GnnExplainer", + "module_name": "kgcnn.literature.GNNExplain", + "config": { + "learning_rate": 0.01, + "epochs": 250, + "node_sparsity_factor": 0.1, + "edge_sparsity_factor": 0.1, + } + } + }, + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.GCN", + "config": { + "name": "GCN", + 'inputs': [{'shape': (None, 1), 'name': "node_attributes", 'dtype': 'float32', 'ragged': True}, + {'shape': (None, 1), 'name': "edge_attributes", 'dtype': 'float32', 'ragged': True}, + {'shape': (None, 2), 'name': "edge_indices", 'dtype': 'int64', 'ragged': True}], + "gcn_args": {"units": 32, "use_bias": True, "activation": "relu"}, + "depth": 3, "verbose": 10, + "output_embedding": "graph", + "output_mlp": {"use_bias": [True, True, False], "units": [32, 16, 2], + "activation": ["relu", "relu", "softmax"]}, + } + }, + "training": { + "fit": { + "batch_size": 32, + "epochs": 100, + "validation_freq": 10, + "verbose": 2, + }, + "compile": { + "optimizer": {"class_name": "Nadam", "config": {"lr": 1e-02}}, + "loss": "categorical_crossentropy", + "metrics": ["categorical_accuracy"], + }, + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "scaler": {"class_name": "StandardScaler", "config": {"with_std": True, "with_mean": True, "copy": True}}, + }, + "data": { + "dataset": { + "class_name": "VgdMockDataset", + "module_name": "kgcnn.data.datasets.VgdMockDataset", + "config": {}, + "methods": [] + }, + "data_unit": "" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "2.0.3" + } + } } \ No newline at end of file diff --git a/training/results/README.md b/training/results/README.md index cf24bf16..a8fc503b 100644 --- a/training/results/README.md +++ b/training/results/README.md @@ -290,9 +290,10 @@ Energies and forces for molecular dynamics trajectories. All geometries in A, en Synthetic classification dataset containing 100 small, randomly generated graphs, where half of them were seeded with a triangular subgraph motif, which is the explanation ground truth for the target class distinction. -| model | kgcnn | epochs | Categorical Accuracy | Node AUC | Edge AUC | -|:--------|:--------|---------:|:-----------------------|:-----------------------|:-----------------------| -| MEGAN | 2.2.0 | 100 | **0.9400 ± 0.0490** | **0.8873 ± 0.0250** | **0.9518 ± 0.0241** | +| model | kgcnn | epochs | Categorical Accuracy | Node AUC | Edge AUC | +|:-----------------|:--------|---------:|:-----------------------|:-----------------------|:-----------------------| +| GCN_GnnExplainer | 2.2.1 | 100 | 0.8700 ± 0.1122 | 0.7621 ± 0.0357 | 0.6051 ± 0.0416 | +| MEGAN | 2.2.0 | 100 | **0.9400 ± 0.0490** | **0.8873 ± 0.0250** | **0.9518 ± 0.0241** | ## VgdRbMotifsDataset diff --git a/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_VgdMockDataset_score.yaml b/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_VgdMockDataset_score.yaml new file mode 100644 index 00000000..ffcd1027 --- /dev/null +++ b/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_VgdMockDataset_score.yaml @@ -0,0 +1,159 @@ +categorical_accuracy: +- 0.737500011920929 +- 0.925000011920929 +- 0.949999988079071 +- 0.987500011920929 +- 0.9624999761581421 +data_unit: '' +date_time: '2023-01-30 11:22:32' +edge_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +epochs: +- 100 +- 100 +- 100 +- 100 +- 100 +execute_folds: null +kgcnn_version: 2.2.1 +loss: +- 0.4342142939567566 +- 0.16217032074928284 +- 0.09488578140735626 +- 0.04377737268805504 +- 0.08145208656787872 +max_categorical_accuracy: +- 0.9750000238418579 +- 0.987500011920929 +- 0.9750000238418579 +- 1.0 +- 1.0 +max_edge_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +max_loss: +- 4.44147253036499 +- 2.8822808265686035 +- 2.0316762924194336 +- 1.1726633310317993 +- 1.1836830377578735 +max_node_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +max_val_categorical_accuracy: +- 1.0 +- 0.949999988079071 +- 1.0 +- 0.949999988079071 +- 1.0 +max_val_edge_auc: +- 0.662757487539147 +- 0.6403508771929826 +- 0.5727952453987729 +- 0.5992169540229886 +- 0.5503513563991222 +max_val_loss: +- 0.8895169496536255 +- 1.0942341089248657 +- 0.4583742618560791 +- 0.991496205329895 +- 0.45864877104759216 +max_val_node_auc: +- 0.8319827120475419 +- 0.7459549878345498 +- 0.74553264604811 +- 0.7320391414141414 +- 0.7547622735016262 +min_categorical_accuracy: +- 0.4749999940395355 +- 0.42500001192092896 +- 0.4749999940395355 +- 0.375 +- 0.512499988079071 +min_edge_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +min_loss: +- 0.06907949596643448 +- 0.05166038125753403 +- 0.07260012626647949 +- 0.0417063906788826 +- 0.04308139532804489 +min_node_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +min_val_categorical_accuracy: +- 0.550000011920929 +- 0.5 +- 0.75 +- 0.6499999761581421 +- 0.8500000238418579 +min_val_edge_auc: +- 0.662757487539147 +- 0.6403508771929826 +- 0.5727952453987729 +- 0.5992169540229886 +- 0.5503513563991222 +min_val_loss: +- 0.08486564457416534 +- 0.10219880193471909 +- 0.017678719013929367 +- 0.09296995401382446 +- 0.0624590627849102 +min_val_node_auc: +- 0.8319827120475419 +- 0.7459549878345498 +- 0.74553264604811 +- 0.7320391414141414 +- 0.7547622735016262 +model_class: make_model +model_name: GCN_GnnExplainer +model_version: '' +multi_target_indices: null +node_auc: +- 0.0 +- 0.0 +- 0.0 +- 0.0 +- 0.0 +number_histories: 5 +val_categorical_accuracy: +- 0.6499999761581421 +- 0.8999999761581421 +- 0.949999988079071 +- 0.8999999761581421 +- 0.949999988079071 +val_edge_auc: +- 0.662757487539147 +- 0.6403508771929826 +- 0.5727952453987729 +- 0.5992169540229886 +- 0.5503513563991222 +val_loss: +- 0.509390115737915 +- 0.16633762419223785 +- 0.06243344023823738 +- 0.09296995401382446 +- 0.14829257130622864 +val_node_auc: +- 0.8319827120475419 +- 0.7459549878345498 +- 0.74553264604811 +- 0.7320391414141414 +- 0.7547622735016262 diff --git a/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_hyper.json b/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_hyper.json new file mode 100644 index 00000000..01315f08 --- /dev/null +++ b/training/results/VgdMockDataset/GCN_GnnExplainer/GCN_hyper.json @@ -0,0 +1 @@ +{"explanation": {"channels": 2, "gt_suffix": null}, "xai_methods": {"Mock": {"class_name": "MockImportanceExplanationMethod", "module_name": "kgcnn.xai.base", "config": {}}, "GnnExplainer": {"class_name": "GnnExplainer", "module_name": "kgcnn.literature.GNNExplain", "config": {"learning_rate": 0.01, "epochs": 250, "node_sparsity_factor": 0.1, "edge_sparsity_factor": 0.1}}}, "model": {"class_name": "make_model", "module_name": "kgcnn.literature.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 1], "name": "node_attributes", "dtype": "float32", "ragged": true}, {"shape": [null, 1], "name": "edge_attributes", "dtype": "float32", "ragged": true}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64", "ragged": true}], "gcn_args": {"units": 32, "use_bias": true, "activation": "relu"}, "depth": 3, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [32, 16, 2], "activation": ["relu", "relu", "softmax"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 100, "validation_freq": 10, "verbose": 2}, "compile": {"optimizer": {"class_name": "Nadam", "config": {"lr": 0.01}}, "loss": "categorical_crossentropy", "metrics": ["categorical_accuracy"]}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "data": {"dataset": {"class_name": "VgdMockDataset", "module_name": "kgcnn.data.datasets.VgdMockDataset", "config": {}, "methods": []}, "data_unit": ""}, "info": {"postfix": "_GnnExplainer", "postfix_file": "", "kgcnn_version": "2.0.3"}} \ No newline at end of file diff --git a/training/train_visual_graph_dataset.py b/training/train_visual_graph_dataset.py index 87674d10..f0a601b3 100644 --- a/training/train_visual_graph_dataset.py +++ b/training/train_visual_graph_dataset.py @@ -8,6 +8,7 @@ import click import numpy as np +import matplotlib.colors as mcolors from sklearn.model_selection import KFold from sklearn.preprocessing import minmax_scale from sklearn.metrics import roc_auc_score @@ -64,7 +65,6 @@ def main(model: str, input graph (nodes & edges) which determine how important that respective element was for the outcome of the target value prediction. """ - model_name = model # We are doing all the kgcnn imports only now because importing kgcnn will start the tensorflow runtime # which may take a few seconds and which may print a few error messages. In principle that is not a bad @@ -78,13 +78,15 @@ def main(model: str, from kgcnn.utils.plots import plot_train_test_loss from kgcnn.training.history import save_history_score from kgcnn.xai.utils import flatten_importances_list + from kgcnn.xai.base import AbstractExplanationMethod # Tensorflow warnings are really annoying, so we only show them if the flag is explicitly set if not show_warnings: warnings.filterwarnings("ignore") os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - echo_info(f'attempting training of model "{model}" and XAI method "{xai_method}" on dataset "{dataset}"') + echo_info(f'attempting training of model "{model}" and XAI method "{xai_method}" ' + f'on dataset "{dataset}"') # == LOADING HYPER PARAMETERS == # Technically the values provided through the command line options do not provide enough information to @@ -103,6 +105,26 @@ def main(model: str, dataset_name=dataset ) + # We already want to check if the xai method is even defined for the given model in the hyperparameters + # if that is not the case we can already terminate here. + if xai_method not in hyper_params['xai_methods']: + raise NotImplementedError(f'It seems you are attempting to use the xai_method "{xai_method}"' + f'with the model "{model_name}". However, there is no entry ' + f'found for this combination in the given hyperparameter file! ' + f'Please add an entry for that xai_method in the "xai_methods" ' + f'section of the model "{model_name}" and try again.') + + # We want to consider every combination of model & explanation method as unique contender in the + # benchmarking, which is why we set the postfix here as the name of the explanation method to make + # sure that each combination gets it's own results folder. + # We also need to modify the model name so that it is correctly displayed later on in the summary as + # an individual entry + if xai_method is not None: + hyper_params._hyper['info']['postfix'] = '_' + xai_method + model_name = model + '_' + xai_method + else: + model_name = model + # == CREATING RESULTS FOLDER == # Since we are about to generate a bunch of artifacts, we create a new directory here where we are going # to save all of those artifacts into, so we don't accidentally clutter an important folder of the user. @@ -197,8 +219,35 @@ def main(model: str, echo_info(f'done processing split {splits_done}') # == CREATING EXPLANATIONS == - # - node_importances, edge_importances = model.explain_importances(x_test) + # Currently this module only supports "importance" explanations, where explanations basically + # consist of attributing each of the input nodes and edges of the original graph with a single + # (0, 1) importance value. + + # Here we need to make a difference between self-explaining models and explaining other models + # using black-box explainability methods. + + if xai_method is None: + node_importances, edge_importances = model.explain_importances(x_test) + + else: + xai_config: dict = hyper_params['xai_methods'][xai_method] + xai_class: type = get_model_class( + class_name=xai_config['class_name'], + module_name=xai_config['module_name'] + ) + xai_instance: AbstractExplanationMethod = xai_class( + channels=num_importance_channels, + **xai_config['config'] + ) + + # __call__ of that instance implements the actual explanation process. + # Beware that this could take some time! + node_importances, edge_importances = xai_instance( + model=model, + x=x_test, + y=y_test, + ) + node_importances = [minmax_scale(a) for a in node_importances.numpy()] edge_importances = [minmax_scale(a) for a in edge_importances.numpy()] @@ -237,30 +286,25 @@ def main(model: str, hist.history['val_edge_auc'] = [edge_auc] # == CREATING VISUALIZATIONS == + test_indices_list = test_indices.tolist() example_indices = random.sample( - test_indices.tolist(), + test_indices_list, k=int(len(test_indices) * visualization_ratio) ) example_indices = sorted(example_indices) - echo_info(f'creating explanation visualizations for {len(example_indices)} indices from test set') - - # Now to create the visualizations for these chosen example elements, we first need to create the - # explanations for them. - x_example = visual_graph_dataset[example_indices].tensor([ - {'name': 'node_attributes', 'ragged': True}, - {'name': 'edge_attributes', 'ragged': True}, - {'name': 'edge_indices', 'ragged': True} - ]) - node_importances, edge_importances = model.explain_importances(x_example) - node_importances = node_importances.numpy() - edge_importances = edge_importances.numpy() - - pdf_path = os.path.join(results_path, f'importances_split_{splits_done}.pdf') + node_importances_example = [] + edge_importances_example = [] + for index in example_indices: + c = test_indices_list.index(index) + node_importances_example.append(node_importances[c]) + edge_importances_example.append(edge_importances[c]) + + pdf_path = os.path.join(results_path, f'examples__split_{splits_done}.pdf') visual_graph_dataset.visualize_importances( output_path=pdf_path, gt_importances_suffix=str(num_importance_channels), - node_importances_list=node_importances, - edge_importances_list=edge_importances, + node_importances_list=node_importances_example, + edge_importances_list=edge_importances_example, indices=example_indices, )