Skip to content

Commit

Permalink
Update WLKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
vladislavalerievich committed Dec 7, 2024
1 parent 1c4cc83 commit dab9a8c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 32 deletions.
37 changes: 7 additions & 30 deletions grakel_replace/mixed_single_task_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,35 +55,12 @@ def forward(
Tensor: The computed kernel matrix.
"""
if x2 is None:
K = self._K_train
# Handle batch dimension if present in x1
if x1.dim() > 2: # We have a batch dimension
batch_size = x1.size(0)
target_size = x1.size(1) # This should be 11 in our case
# Resize K to match the expected dimensions
K = K.unsqueeze(0) # Add batch dimension
# Pad or interpolate K to match target size
if K.size(1) != target_size:
K_resized = torch.zeros(1, target_size, target_size, dtype=K.dtype,
device=K.device)
K_resized[:, :K.size(1), :K.size(2)] = K
K = K_resized
K = K.expand(batch_size, target_size, target_size)
return K.to(dtype=x1.dtype)

# Similar logic for cross-kernel case
test_dataset = GraphDataset.from_networkx(x2)
K = self._wl_kernel(self._train_graph_dataset, test_dataset)
if x1.dim() > 2:
batch_size = x1.size(0)
target_size = x1.size(1)
if K.size(0) != target_size:
K_resized = torch.zeros(target_size, target_size, dtype=K.dtype,
device=K.device)
K_resized[:K.size(0), :K.size(1)] = K
K = K_resized
K = K.unsqueeze(0).expand(batch_size, target_size, target_size)
return K.to(dtype=x1.dtype)
# Return the precomputed training kernel matrix
return self._K_train

# Compute cross-kernel between training graphs and new test graphs
test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs
return self._wl_kernel(self._train_graph_dataset, test_dataset)


class MixedSingleTaskGP(SingleTaskGP):
Expand Down Expand Up @@ -185,7 +162,7 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
if not all(isinstance(g, nx.Graph) for g in graphs):
raise TypeError("Expected input type is a list of NetworkX graphs.")

# Process the new graph inputs into a compatible dataset
# Process the new graph inputs into a compatible dataset
proc_graphs = GraphDataset.from_networkx(graphs)

# Compute the kernel matrix for the new graphs
Expand Down
4 changes: 2 additions & 2 deletions grakel_replace/mixed_single_task_gp_usage_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@

print("\nMean:", predictions)
print("Variance:", uncertainties)
print("Covariance matrix:", covar)

# =============== Fitting the GP using botorch ===============

print("\nFitting the GP model using botorch...")

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
Expand Down Expand Up @@ -152,7 +152,7 @@
bounds=bounds,
fixed_features_list=fixed_cats,
train_graphs=train_graphs,
num_graph_samples=6,
num_graph_samples=20,
num_restarts=10,
raw_samples=10,
q=1,
Expand Down

0 comments on commit dab9a8c

Please sign in to comment.