diff --git a/test/data/test_stzinbgraph_data.py b/test/data/test_stzinbgraph_data.py
new file mode 100644
index 000000000000..6969925d8020
--- /dev/null
+++ b/test/data/test_stzinbgraph_data.py
@@ -0,0 +1,77 @@
+import torch
+
+from torch_geometric.data import STZINBGraph
+from torch_geometric.loader import DataLoader
+
+
+def test_stzinb_graph():
+    # Test data
+    od_pairs = [(0, 1), (1, 2), (2, 3)]  # Example O-D pairs
+    demand = torch.rand(4, 10)  # 4 nodes (O-D pairs), 10 time windows
+    time_windows = [
+        "t1", "t2", "t3", "t4", "t5", "t6", "t7", "t8", "t9", "t10"
+    ]
+
+    # Create an O-D graph
+    od_graph = STZINBGraph.from_od_data(od_pairs, demand, time_windows)
+
+    # Compute the number of unique nodes
+    unique_nodes = set()
+    for o, d in od_pairs:
+        unique_nodes.add(o)
+        unique_nodes.add(d)
+    num_nodes = len(unique_nodes)
+
+    # Assertions for edge_index
+    assert od_graph.edge_index is not None, "Edge index should not be None."
+    assert (od_graph.edge_index.size(0) == 2
+            ), "Edge index should have two rows (source, target)."
+    assert (od_graph.edge_index.size(1)
+            > 0), "Edge index should have at least one edge."
+
+    # Assertions for node features (x)
+    assert od_graph.x is not None, "Node features (x) should not be None."
+    assert (
+        od_graph.x.size(0) == num_nodes
+    ), f"Node features should match the number of unique nodes ({num_nodes})."
+    assert (
+        od_graph.x.size(1) == 1
+    ), "Node features should have a single feature dim (aggregated demand)."
+
+    # Assertions for temporal features (time_series)
+    assert (od_graph.time_series
+            is not None), "Temporal features (time_series) should not be None."
+    assert (od_graph.time_series.size() == demand.size()
+            ), "Temporal features should match the input demand matrix."
+
+    # Assertions for adjacency matrix (adj)
+    assert od_graph.adj is not None, (
+        "Adjacency matrix (adj) should not be None.")
+
+    assert (od_graph.adj.size(0) == num_nodes
+            ), f"Adjacency matrix should be square ({num_nodes}x{num_nodes})."
+    assert (od_graph.adj.size(1) == num_nodes
+            ), f"Adjacency matrix should be square ({num_nodes}x{num_nodes})."
+    assert torch.all(od_graph.adj.diagonal() ==
+                     1), "Self-loops should exist (diagonal entries = 1)."
+
+    # Integration with DataLoader
+    graphs = [od_graph] * 5  # Create a small dataset with 5 identical graphs
+    loader = DataLoader(graphs, batch_size=2, shuffle=True)
+
+    for batch in loader:
+        assert batch.batch is not None, (
+            "Batch object should have a 'batch' attribute.")
+        assert batch.x is not None, (
+            "Batched data should have node features (x).")
+        assert batch.edge_index is not None, (
+            "Batched data should have edge indices.")
+        assert batch.time_series is not None, (
+            "Batched data should retain temporal features (time_series).")
+
+    print("STZINBGraph test passed successfully!")
+
+
+# Run the test
+if __name__ == "__main__":
+    test_stzinb_graph()
diff --git a/test/nn/losses/test_zinb_loss.py b/test/nn/losses/test_zinb_loss.py
new file mode 100644
index 000000000000..8c6ea3c215e8
--- /dev/null
+++ b/test/nn/losses/test_zinb_loss.py
@@ -0,0 +1,62 @@
+import torch
+
+from torch_geometric.nn.losses import ZINBLoss
+
+
+def test_zinb_loss():
+    # Create dummy data
+    target = torch.tensor([0.0, 1.0, 2.0, 0.0, 4.0], dtype=torch.float32)
+    mu = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0],
+                      dtype=torch.float32)  # Mean of NB
+    theta = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0],
+                         dtype=torch.float32)  # Dispersion
+    pi = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1],
+                      dtype=torch.float32)  # Zero-inflation prob
+
+    # Initialize the loss
+    loss_fn = ZINBLoss()
+
+    # Compute the loss
+    loss = loss_fn((mu, theta, pi), target)
+
+    # Check the loss value
+    assert loss >= 0, "Loss should be non-negative."
+    assert torch.isfinite(
+        loss), "Loss should not contain NaN or infinite values."
+
+    # Closed-form reference: for mu = theta = 1, NB(y) = 0.5 ** (y + 1), and
+    # zero targets must go through the zero-inflated branch
+    nb_prob = 0.5**(target + 1)
+    likelihood = torch.where(target == 0, 0.1 + 0.9 * nb_prob, 0.9 * nb_prob)
+    assert torch.allclose(loss, -likelihood.log().mean(), atol=1e-5)
+
+    # Print the loss value for debugging
+    print(f"ZINB Loss: {loss.item()}")
+
+
+def test_zinb_loss_shapes():
+    # Test with batched data
+    batch_size = 4
+    target = torch.rand(batch_size, 10)  # 10 targets per batch
+    mu = torch.rand(batch_size, 10) + 0.1  # Mean of NB, avoiding zero
+    theta = torch.rand(batch_size, 10) + 0.1  # Dispersion, avoiding zero
+    pi = torch.rand(batch_size, 10)  # Zero-inflation probability
+
+    # Initialize the loss
+    loss_fn = ZINBLoss()
+
+    # Compute the loss
+    loss = loss_fn((mu, theta, pi), target)
+
+    # Check the loss value
+    assert loss >= 0, "Loss should be non-negative."
+    assert torch.isfinite(
+        loss), "Loss should not contain NaN or infinite values."
+
+    # Print the loss value for debugging
+    print(f"Batched ZINB Loss: {loss.item()}")
+
+
+if __name__ == "__main__":
+    test_zinb_loss()
+    test_zinb_loss_shapes()
diff --git a/test/nn/models/test_stzinb_gnn.py b/test/nn/models/test_stzinb_gnn.py
new file mode 100644
index 000000000000..47e82264f770
--- /dev/null
+++ b/test/nn/models/test_stzinb_gnn.py
@@ -0,0 +1,65 @@
+import torch
+
+from torch_geometric.data import Batch, Data
+from torch_geometric.nn import STZINBGNN
+
+
+def test_stzinb_gnn():
+    # Define model parameters
+    num_nodes = 50
+    num_features = 10
+    time_window = 5
+    hidden_dim_s = 70
+    hidden_dim_t = 7
+    rank_s = 20
+    rank_t = 4
+    k = 4
+    batch_size = 8
+
+    # Create dummy data for testing
+    edge_index = torch.randint(0, num_nodes, (2, 200),
+                               dtype=torch.long)  # Random edges
+    x = torch.rand(num_nodes, num_features)  # Random node features
+    y = torch.randint(0, 10, (num_nodes, k))  # Random target for ZINB loss
+
+    # Create PyG Data object
+    data = Data(x=x, edge_index=edge_index, y=y)
+
+    # Create a batch of graphs
+    data_list = [data.clone() for _ in range(batch_size)]
+    batch = Batch.from_data_list(data_list)  # Use Batch to create a batch
+
+    # Initialize the model
+    model = STZINBGNN(
+        num_nodes=num_nodes,
+        num_features=num_features,
+        time_window=time_window,
+        hidden_dim_s=hidden_dim_s,
+        hidden_dim_t=hidden_dim_t,
+        rank_s=rank_s,
+        rank_t=rank_t,
+        k=k,
+    )
+
+    # Forward pass
+    pi, n, p = model(batch.x, batch.edge_index, batch.batch)
+
+    # Assertions for output shapes
+    assert pi.shape == (batch_size * num_nodes,
+                        1), "Incorrect shape for π (zero-inflation parameter)"
+    assert n.shape == (batch_size * num_nodes,
+                       1), "Incorrect shape for n (shape parameter)"
+    assert p.shape == (batch_size * num_nodes,
+                       1), "Incorrect shape for p (probability parameter)"
+
+    # Check value ranges
+    assert (pi >= 0).all() and (pi <= 1).all(), "π values out of range [0, 1]"
+    assert (n > 0).all(), "n values should be strictly positive"
+    assert (p >= 0).all() and (p <= 1).all(), "p values out of range [0, 1]"
+
+    print("STZINBGNN test passed successfully!")
+
+
+# Run the test
+if __name__ == "__main__":
+    test_stzinb_gnn()
diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py
index fee215b1a357..f94a4c52d89b 100644
--- a/torch_geometric/data/__init__.py
+++ b/torch_geometric/data/__init__.py
@@ -17,6 +17,7 @@
 from .download import download_url, download_google_url
 from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
 from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
+from .stzinbgraph_data import STZINBGraph
 
 from torch_geometric.lazy_loader import LazyLoader
 
diff --git a/torch_geometric/data/stzinbgraph_data.py b/torch_geometric/data/stzinbgraph_data.py
new file mode 100644
index 000000000000..f9d311401a9a
--- /dev/null
+++ b/torch_geometric/data/stzinbgraph_data.py
@@ -0,0 +1,99 @@
+import torch
+
+from torch_geometric.data.data import Data
+
+
+class STZINBGraph(Data):
+    """A class to represent an Origin-Destination (O-D) graph with temporal
+    travel demand information.
+
+    Attributes:
+        edge_index (torch.Tensor): Tensor defining the edges between O-D pairs.
+        x (torch.Tensor): Node features representing O-D pairs.
+        time_series (torch.Tensor): Temporal features related to travel demand
+            across time windows.
+        adj (torch.Tensor): Adjacency matrix for the O-D graph.
+    """
+    def __init__(self, edge_index=None, x=None, time_series=None, adj=None,
+                 **kwargs):
+        super().__init__(edge_index=edge_index, x=x, **kwargs)
+        self.time_series = time_series
+        self.adj = adj
+
+    @classmethod
+    def from_od_data(cls, od_pairs, demand, time_windows, fully_connected=True,
+                     normalize_adj=True):
+        """Builds an O-D graph from raw data including O-D pairs, demand
+        matrices, and time windows.
+
+        Args:
+            od_pairs (list of tuples): List of O-D pairs (e.g., [(origin1,
+                dest1), (origin2, dest2), ...]).
+            demand (torch.Tensor): Demand matrix of shape [num_nodes,
+                num_time_windows].
+            time_windows (list): List of time window labels (e.g., ['t1',
+                't2', ..., 'tn']). Currently not used by this method.
+            fully_connected (bool): If True, creates a fully connected graph.
+                Defaults to True.
+            normalize_adj (bool): If True, normalizes the adjacency matrix.
+                Defaults to True.
+
+        Returns:
+            STZINBGraph: A graph object with node features, edges, adjacency,
+                and temporal features.
+        """
+        # Determine the set of unique nodes
+        unique_nodes = set()
+        for o, d in od_pairs:
+            unique_nodes.add(o)
+            unique_nodes.add(d)
+        unique_nodes = sorted(unique_nodes)
+        num_nodes = len(unique_nodes)
+
+        # Map original indices to continuous indices if necessary
+        node_mapping = {node: idx for idx, node in enumerate(unique_nodes)}
+        od_pairs_mapped = [(node_mapping[o], node_mapping[d])
+                           for o, d in od_pairs]
+
+        # Create edge_index
+        if fully_connected:
+            # Fully connected graph
+            edge_index = torch.combinations(torch.arange(num_nodes), r=2).t()
+            edge_index = torch.cat([edge_index, edge_index.flip(0)],
+                                   dim=1)  # Symmetric edges
+        else:
+            # Custom graph based on O-D adjacency
+            edges = [(o, d) for o, d in od_pairs_mapped]
+            edge_index = torch.tensor(edges, dtype=torch.long).t()
+
+        # Build adjacency matrix
+        adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
+        for o, d in od_pairs_mapped:
+            adj[o, d] = 1  # Directed edge
+            adj[d, o] = 1  # Undirected graph (symmetric adjacency matrix)
+
+        # Add self-loops (assign rather than add, so an (o, o) pair keeps 1)
+        adj.fill_diagonal_(1)
+
+        # Normalize adjacency matrix (excluding self-loops)
+        if normalize_adj:
+            row_sums = adj.sum(
+                dim=1, keepdim=True) - 1  # Subtract self-loop contribution
+            row_sums[row_sums == 0] = 1  # Avoid division by zero
+            non_diag_mask = ~torch.eye(num_nodes, dtype=torch.bool)
+            adj[non_diag_mask] /= row_sums.expand_as(adj)[non_diag_mask]
+
+        # Adjust demand to match the number of nodes
+        if demand.size(0) != num_nodes:
+            raise ValueError(
+                f"The number of rows in the demand matrix ({demand.size(0)}) "
+                f"must match the number of unique nodes ({num_nodes}).")
+
+        # Node features: Aggregate demand over all time windows
+        x = demand.mean(dim=1, keepdim=True)  # Example: mean demand per node
+
+        # Temporal features: Full time-series demand
+        time_series = demand
+
+        return cls(edge_index=edge_index, x=x, time_series=time_series,
+                   adj=adj)
diff --git a/torch_geometric/nn/__init__.py b/torch_geometric/nn/__init__.py
index 5c615d6e9b45..217f43e898f9 100644
--- a/torch_geometric/nn/__init__.py
+++ b/torch_geometric/nn/__init__.py
@@ -6,6 +6,7 @@
 from .to_fixed_size_transformer import to_fixed_size
 from .encoding import PositionalEncoding, TemporalEncoding
 from .summary import summary
+from .losses import ZINBLoss
 
 from .aggr import *  # noqa
 from .conv import *  # noqa
diff --git a/torch_geometric/nn/losses/__init__.py b/torch_geometric/nn/losses/__init__.py
new file mode 100644
index 000000000000..9c8891a081ed
--- /dev/null
+++ b/torch_geometric/nn/losses/__init__.py
@@ -0,0 +1,3 @@
+from .zinb_loss import ZINBLoss
+
+__all__ = ['ZINBLoss']
diff --git a/torch_geometric/nn/losses/zinb_loss.py b/torch_geometric/nn/losses/zinb_loss.py
new file mode 100644
index 000000000000..d420afe7b5e4
--- /dev/null
+++ b/torch_geometric/nn/losses/zinb_loss.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+
+
+class ZINBLoss(nn.Module):
+    """Custom loss function for the Zero-Inflated Negative Binomial (ZINB)
+    distribution.
+
+    Args:
+        eps (float): A small constant to avoid division by zero or log(0).
+    """
+    def __init__(self, eps=1e-8):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, prediction, target):
+        """Computes the negative log-likelihood of the ZINB distribution.
+
+        Args:
+            prediction (tuple): A tuple containing (mu, theta, pi) from the
+                model:
+                - mu: Mean of the Negative Binomial distribution.
+                - theta: Dispersion parameter (greater than 0).
+                - pi: Zero-inflation probability (between 0 and 1).
+            target (torch.Tensor): Ground truth values.
+
+        Returns:
+            torch.Tensor: The computed ZINB loss.
+        """
+        mu, theta, pi = prediction
+        return self.compute_zinb_loss(mu, theta, pi, target)
+
+    def compute_zinb_loss(self, mu, theta, pi, target):
+        """Computes the Zero-Inflated Negative Binomial loss components.
+
+        Args:
+            mu (torch.Tensor): Mean of the Negative Binomial distribution.
+            theta (torch.Tensor): Dispersion parameter (greater than 0).
+            pi (torch.Tensor): Zero-inflation probability (between 0 and 1).
+            target (torch.Tensor): Ground truth counts (non-negative).
+
+        Returns:
+            torch.Tensor: The computed ZINB loss.
+        """
+        # Flag zero observations *before* clamping: once the target has been
+        # clamped, the mask could never be true and the zero-inflation branch
+        # would be dead code
+        is_zero = target < self.eps
+
+        # Ensure valid values for stability
+        mu = torch.clamp(mu, min=self.eps)
+        theta = torch.clamp(theta, min=self.eps)
+        pi = torch.clamp(pi, min=self.eps, max=1 - self.eps)
+        target = torch.clamp(target, min=0).to(mu.dtype)
+
+        # Log-likelihood of the Negative Binomial (NB) component
+        log_nb = (torch.lgamma(theta + target) - torch.lgamma(target + 1) -
+                  torch.lgamma(theta) + theta * torch.log(theta) +
+                  target * torch.log(mu) -
+                  (theta + target) * torch.log(theta + mu))
+
+        # Log-likelihood of a zero count: log(pi + (1 - pi) * NB(0)), where
+        # NB(0) = (theta / (theta + mu)) ** theta
+        log_nb_zero = theta * (torch.log(theta) - torch.log(theta + mu))
+        log_zero_inflated = torch.log(pi + (1 - pi) * torch.exp(log_nb_zero))
+
+        # Log-likelihood for non-zero values
+        log_non_zero = torch.log(1 - pi) + log_nb
+
+        # Combine likelihoods based on target values
+        zinb_loss = torch.where(
+            is_zero,  # If the target is zero
+            -log_zero_inflated,  # Use zero-inflated likelihood
+            -log_non_zero,  # Use regular NB likelihood
+        )
+
+        return zinb_loss.mean()  # Return the mean loss across all samples
diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py
index 9ade58cebc05..d33731666193 100644
--- a/torch_geometric/nn/models/__init__.py
+++ b/torch_geometric/nn/models/__init__.py
@@ -32,6 +32,7 @@
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
 from .glem import GLEM
+from .stzinb_gnn import STZINBGNN
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -82,4 +83,5 @@
     'GITMol',
     'MoleculeGPT',
     'GLEM',
+    'STZINBGNN',
 ]
diff --git a/torch_geometric/nn/models/stzinb_gnn.py b/torch_geometric/nn/models/stzinb_gnn.py
new file mode 100644
index 000000000000..3ed10afce5b1
--- /dev/null
+++ b/torch_geometric/nn/models/stzinb_gnn.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+from torch.nn.functional import softplus
+
+from torch_geometric.nn.conv import GCNConv
+
+
+class STZINBGNN(nn.Module):
+    r"""A spatio-temporal GNN that predicts the parameters of a zero-inflated
+    negative binomial (ZINB) distribution for every node.
+
+    Node features pass through three :class:`GCNConv` layers, three
+    :class:`torch.nn.Conv1d` layers applied along the node dimension of
+    each graph, and three linear heads producing :obj:`pi`, :obj:`n` and
+    :obj:`p`.
+
+    Args:
+        num_nodes (int): Number of nodes per graph (currently unused).
+        num_features (int): Number of input features per node.
+        time_window (int): Length of the time window (currently unused).
+        hidden_dim_s (int): Hidden size of the spatial layers.
+        hidden_dim_t (int): Hidden size of the temporal layers.
+        rank_s (int): Bottleneck size of the spatial layers.
+        rank_t (int): Bottleneck size of the temporal layers.
+        k (int): Number of time windows to predict (stored only).
+    """
+    def __init__(self, num_nodes, num_features, time_window, hidden_dim_s,
+                 hidden_dim_t, rank_s, rank_t, k):
+        super().__init__()
+
+        # Spatial Layers (Replaces D_GCN)
+        self.spatial_conv1 = GCNConv(num_features, hidden_dim_s)
+        self.spatial_conv2 = GCNConv(hidden_dim_s, rank_s)
+        self.spatial_conv3 = GCNConv(rank_s, hidden_dim_s)
+
+        # Temporal Layers (Replaces B_TCN)
+        self.temporal_conv1 = nn.Conv1d(in_channels=hidden_dim_s,
+                                        out_channels=hidden_dim_t,
+                                        kernel_size=3, padding=1)
+        self.temporal_conv2 = nn.Conv1d(in_channels=hidden_dim_t,
+                                        out_channels=rank_t, kernel_size=3,
+                                        padding=1)
+        self.temporal_conv3 = nn.Conv1d(in_channels=rank_t,
+                                        out_channels=hidden_dim_t,
+                                        kernel_size=3, padding=1)
+
+        # ZINB Layer Parameters
+        self.fc_pi = nn.Linear(hidden_dim_t, 1)  # Zero-inflation parameter
+        self.fc_n = nn.Linear(hidden_dim_t, 1)  # Shape parameter
+        self.fc_p = nn.Linear(hidden_dim_t, 1)  # Probability parameter
+
+        # Time windows for prediction
+        self.k = k
+
+    def forward(self, x, edge_index, batch=None):
+        """Returns the ZINB parameters :obj:`(pi, n, p)`, each of shape
+        :obj:`[num_total_nodes, 1]`.
+
+        :obj:`batch` assigns each node to a graph; :obj:`None` means a single
+        graph. All graphs are assumed to have the same number of nodes.
+        """
+        # Spatial Embedding
+        x_s = torch.relu(self.spatial_conv1(x, edge_index))
+        x_s = torch.relu(self.spatial_conv2(x_s, edge_index))
+        x_s = torch.relu(self.spatial_conv3(x_s, edge_index))
+
+        # Reshape for Temporal Processing
+        num_graphs = 1 if batch is None else int(batch.max()) + 1
+        x_s = x_s.view(num_graphs, -1,
+                       x_s.size(-1))  # [num_graphs, num_nodes, hidden_dim_s]
+        x_s = x_s.permute(0, 2,
+                          1)  # For TCN: [batch_size, hidden_dim_s, num_nodes]
+
+        # Temporal Embedding
+        x_t = torch.relu(self.temporal_conv1(x_s))
+        x_t = torch.relu(self.temporal_conv2(x_t))
+        x_t = torch.relu(self.temporal_conv3(x_t))
+
+        # Flatten for ZINB parameterization
+        x_t = x_t.permute(0, 2, 1).reshape(
+            -1, x_t.size(1))  # Combine spatial and temporal features
+
+        # ZINB Parameters
+        pi = torch.sigmoid(
+            self.fc_pi(x_t))  # Shape: [batch_size * num_nodes, 1]
+        n = softplus(self.fc_n(x_t))  # Shape: [batch_size * num_nodes, 1]
+        p = torch.sigmoid(self.fc_p(x_t))  # Shape: [batch_size * num_nodes, 1]
+
+        return pi, n, p