From dab9a8c81c2d54a3fd07380c5852f561f4182753 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 22:46:45 +0100 Subject: [PATCH] Update WLKernel --- grakel_replace/mixed_single_task_gp.py | 37 ++++--------------- .../mixed_single_task_gp_usage_example.py | 4 +- 2 files changed, 9 insertions(+), 32 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index de04a14a..9f322d05 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -55,35 +55,12 @@ def forward( Tensor: The computed kernel matrix. """ if x2 is None: - K = self._K_train - # Handle batch dimension if present in x1 - if x1.dim() > 2: # We have a batch dimension - batch_size = x1.size(0) - target_size = x1.size(1) # This should be 11 in our case - # Resize K to match the expected dimensions - K = K.unsqueeze(0) # Add batch dimension - # Pad or interpolate K to match target size - if K.size(1) != target_size: - K_resized = torch.zeros(1, target_size, target_size, dtype=K.dtype, - device=K.device) - K_resized[:, :K.size(1), :K.size(2)] = K - K = K_resized - K = K.expand(batch_size, target_size, target_size) - return K.to(dtype=x1.dtype) - - # Similar logic for cross-kernel case - test_dataset = GraphDataset.from_networkx(x2) - K = self._wl_kernel(self._train_graph_dataset, test_dataset) - if x1.dim() > 2: - batch_size = x1.size(0) - target_size = x1.size(1) - if K.size(0) != target_size: - K_resized = torch.zeros(target_size, target_size, dtype=K.dtype, - device=K.device) - K_resized[:K.size(0), :K.size(1)] = K - K = K_resized - K = K.unsqueeze(0).expand(batch_size, target_size, target_size) - return K.to(dtype=x1.dtype) + # Return the precomputed training kernel matrix + return self._K_train + + # Compute cross-kernel between training graphs and new test graphs + test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs + return self._wl_kernel(self._train_graph_dataset, test_dataset) class MixedSingleTaskGP(SingleTaskGP): @@ -185,7 +162,7 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: if not all(isinstance(g, nx.Graph) for g in graphs): raise TypeError("Expected input type is a list of NetworkX graphs.") - # Process the new graph inputs into a compatible dataset + # Process the new graph inputs into a compatible dataset proc_graphs = GraphDataset.from_networkx(graphs) # Compute the kernel matrix for the new graphs diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 19d1662b..cd8528e4 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -105,9 +105,9 @@ print("\nMean:", predictions) print("Variance:", uncertainties) -print("Covariance matrix:", covar) # =============== Fitting the GP using botorch =============== + print("\nFitting the GP model using botorch...") mll = ExactMarginalLogLikelihood(gp.likelihood, gp) @@ -152,7 +152,7 @@ bounds=bounds, fixed_features_list=fixed_cats, train_graphs=train_graphs, - num_graph_samples=6, + num_graph_samples=20, num_restarts=10, raw_samples=10, q=1,