Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions test/data/test_stzinbgraph_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch

from torch_geometric.data import STZINBGraph
from torch_geometric.loader import DataLoader


def test_stzinb_graph():
    """Build an STZINBGraph from toy O-D data and check tensors and batching.

    The graph is created with the default options of
    ``STZINBGraph.from_od_data``; only shapes, presence of attributes and the
    unit diagonal of the adjacency matrix are verified.
    """
    # Toy input: a chain of three O-D pairs over four nodes, ten time windows.
    od_pairs = [(0, 1), (1, 2), (2, 3)]
    demand = torch.rand(4, 10)
    time_windows = [f"t{step}" for step in range(1, 11)]

    od_graph = STZINBGraph.from_od_data(od_pairs, demand, time_windows)

    # Every endpoint that appears in any pair counts as exactly one node.
    num_nodes = len({node for pair in od_pairs for node in pair})

    # --- edge_index -------------------------------------------------------
    edge_index = od_graph.edge_index
    assert edge_index is not None, "Edge index should not be None."
    assert edge_index.size(0) == 2, (
        "Edge index should have two rows (source, target).")
    assert edge_index.size(1) > 0, (
        "Edge index should have at least one edge.")

    # --- node features (x) ------------------------------------------------
    node_feats = od_graph.x
    assert node_feats is not None, "Node features (x) should not be None."
    assert node_feats.size(0) == num_nodes, (
        f"Node features should match the number of unique nodes ({num_nodes})."
    )
    assert node_feats.size(1) == 1, (
        "Node features should have a single feature dim (aggregated demand).")

    # --- temporal features (time_series) ----------------------------------
    series = od_graph.time_series
    assert series is not None, (
        "Temporal features (time_series) should not be None.")
    assert series.size() == demand.size(), (
        "Temporal features should match the input demand matrix.")

    # --- adjacency matrix (adj) -------------------------------------------
    adj = od_graph.adj
    assert adj is not None, "Adjacency matrix (adj) should not be None."
    for dim in (0, 1):
        assert adj.size(dim) == num_nodes, (
            f"Adjacency matrix should be square ({num_nodes}x{num_nodes}).")
    assert torch.all(adj.diagonal() == 1), (
        "Self-loops should exist (diagonal entries = 1).")

    # --- DataLoader integration -------------------------------------------
    # Five references to the same graph are enough to exercise collation.
    graphs = [od_graph] * 5
    loader = DataLoader(graphs, batch_size=2, shuffle=True)

    required_attrs = (
        ("batch", "Batch object should have a 'batch' attribute."),
        ("x", "Batched data should have node features (x)."),
        ("edge_index", "Batched data should have edge indices."),
        ("time_series",
         "Batched data should retain temporal features (time_series)."),
    )
    for batch in loader:
        for attr, message in required_attrs:
            assert getattr(batch, attr) is not None, message

    print("STZINBGraph test passed successfully!")


# Allow running this test module directly as a script (outside of pytest).
if __name__ == "__main__":
    test_stzinb_graph()
56 changes: 56 additions & 0 deletions test/nn/losses/test_zinb_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch

from torch_geometric.nn.losses import ZINBLoss


def test_zinb_loss():
    """Smoke-test ZINBLoss on a small 1-D example with constant parameters.

    Only checks that the scalar loss is non-negative and finite.
    """
    num_samples = 5
    target = torch.tensor([0.0, 1.0, 2.0, 0.0, 4.0], dtype=torch.float32)

    # Constant distribution parameters shared by every sample.
    mu = torch.ones(num_samples, dtype=torch.float32)  # Mean of NB
    theta = torch.ones(num_samples, dtype=torch.float32)  # Dispersion
    pi = torch.full((num_samples, ), 0.1,
                    dtype=torch.float32)  # Zero-inflation prob

    criterion = ZINBLoss()
    loss = criterion((mu, theta, pi), target)

    assert loss >= 0, "Loss should be non-negative."
    assert torch.isfinite(loss), (
        "Loss should not contain NaN or infinite values.")

    # Echo the value so it is visible when the test is run by hand.
    print(f"ZINB Loss: {loss.item()}")


def test_zinb_loss_shapes():
    """Smoke-test ZINBLoss on 2-D (batched) random inputs.

    Only checks that the scalar loss is non-negative and finite.
    """
    batch_size = 4
    shape = (batch_size, 10)  # 10 targets per batch element

    # The four draws are kept in this order so the RNG stream is consumed
    # identically: target, mu, theta, pi.
    target = torch.rand(shape)
    mu = 0.1 + torch.rand(shape)  # Mean of NB, shifted away from zero
    theta = 0.1 + torch.rand(shape)  # Dispersion, shifted away from zero
    pi = torch.rand(shape)  # Zero-inflation probability

    criterion = ZINBLoss()
    loss = criterion((mu, theta, pi), target)

    assert loss >= 0, "Loss should be non-negative."
    assert torch.isfinite(loss), (
        "Loss should not contain NaN or infinite values.")

    # Echo the value so it is visible when the test is run by hand.
    print(f"Batched ZINB Loss: {loss.item()}")


# Allow running both loss tests directly as a script (outside of pytest).
if __name__ == "__main__":
    test_zinb_loss()
    test_zinb_loss_shapes()
65 changes: 65 additions & 0 deletions test/nn/models/test_stzinb_gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch

from torch_geometric.data import Batch, Data
from torch_geometric.nn import STZINBGNN


def test_stzinb_gnn():
    """Run one forward pass of STZINBGNN on a batch of random graphs.

    Checks the shape of the three returned tensors and that their values lie
    in the ranges expected for (pi, n, p).
    """
    num_nodes = 50
    num_features = 10
    k = 4
    batch_size = 8

    # Hyper-parameters forwarded verbatim to the model constructor.
    model_kwargs = dict(
        num_nodes=num_nodes,
        num_features=num_features,
        time_window=5,
        hidden_dim_s=70,
        hidden_dim_t=7,
        rank_s=20,
        rank_t=4,
        k=k,
    )

    # Random graph; the draws happen in the order edges, features, targets.
    edge_index = torch.randint(0, num_nodes, (2, 200), dtype=torch.long)
    x = torch.rand(num_nodes, num_features)
    y = torch.randint(0, 10, (num_nodes, k))
    template = Data(x=x, edge_index=edge_index, y=y)

    # Batch several independent copies of the same graph.
    copies = [template.clone() for _ in range(batch_size)]
    batch = Batch.from_data_list(copies)

    model = STZINBGNN(**model_kwargs)
    pi, n, p = model(batch.x, batch.edge_index, batch.batch)

    # One row per node across the whole batch, one column per parameter.
    expected_shape = (batch_size * num_nodes, 1)
    assert pi.shape == expected_shape, (
        "Incorrect shape for π (zero-inflation parameter)")
    assert n.shape == expected_shape, (
        "Incorrect shape for n (shape parameter)")
    assert p.shape == expected_shape, (
        "Incorrect shape for p (probability parameter)")

    # Value ranges.
    assert ((pi >= 0) & (pi <= 1)).all(), "π values out of range [0, 1]"
    assert (n > 0).all(), "n values should be strictly positive"
    assert ((p >= 0) & (p <= 1)).all(), "p values out of range [0, 1]"

    print("STZINBGNN test passed successfully!")


# Allow running this test module directly as a script (outside of pytest).
if __name__ == "__main__":
    test_stzinb_gnn()
1 change: 1 addition & 0 deletions torch_geometric/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .download import download_url, download_google_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups
from .stzinbgraph_data import STZINBGraph

from torch_geometric.lazy_loader import LazyLoader

Expand Down
99 changes: 99 additions & 0 deletions torch_geometric/data/stzinbgraph_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import torch

from torch_geometric.data import Data


class STZINBGraph(Data):
    """A graph of Origin-Destination (O-D) nodes with temporal travel demand.

    Extends :class:`Data` with two extra attributes, ``time_series`` and
    ``adj``. Use :meth:`from_od_data` to build an instance from raw O-D pairs
    and a demand matrix.

    Attributes:
        edge_index (torch.Tensor): Edges of the graph, shape ``[2, num_edges]``.
        x (torch.Tensor): Node features.
        time_series (torch.Tensor): Temporal demand features across time
            windows.
        adj (torch.Tensor): Dense adjacency matrix of the O-D graph.
    """
    def __init__(self, edge_index=None, x=None, time_series=None, adj=None,
                 **kwargs):
        # Every argument defaults to None so the object can also be created
        # without arguments.
        super().__init__(edge_index=edge_index, x=x, **kwargs)
        self.time_series = time_series
        self.adj = adj

    @classmethod
    def from_od_data(cls, od_pairs, demand: torch.Tensor, time_windows,
                     fully_connected: bool = True,
                     normalize_adj: bool = True):
        """Build an O-D graph from O-D pairs and a demand matrix.

        Node identifiers found in ``od_pairs`` are sorted and re-indexed to
        ``0 .. num_nodes - 1``; row ``i`` of ``demand`` belongs to the node
        with the ``i``-th smallest identifier.

        Args:
            od_pairs (list of tuples): O-D pairs, e.g. ``[(origin1, dest1),
                (origin2, dest2), ...]``. A pair whose origin equals its
                destination is allowed and contributes nothing beyond the
                self-loop every node already gets.
            demand (torch.Tensor): Demand matrix of shape ``[num_nodes,
                num_time_windows]``. Integer counts are accepted.
            time_windows (list): Time window labels (e.g. ``['t1', 't2', ...,
                'tn']``). Currently not used by this method; kept as part of
                the signature.
            fully_connected (bool): If True, ``edge_index`` connects every
                pair of distinct nodes in both directions. If False,
                ``edge_index`` contains exactly the (directed) O-D pairs.
                Defaults to True.
            normalize_adj (bool): If True, the off-diagonal entries of each
                adjacency row are divided by that row's off-diagonal sum; the
                diagonal stays 1. Defaults to True.

        Returns:
            STZINBGraph: Graph with ``edge_index``, ``x`` (mean demand per
            node, shape ``[num_nodes, 1]``), ``time_series`` (the demand
            matrix as given) and ``adj`` (symmetric O-D adjacency with
            self-loops, shape ``[num_nodes, num_nodes]``).

        Raises:
            ValueError: If the number of rows of ``demand`` differs from the
                number of unique nodes in ``od_pairs``.
        """
        # Distinct endpoints; sorting makes the re-indexing deterministic.
        unique_nodes = sorted({node for o, d in od_pairs for node in (o, d)})
        num_nodes = len(unique_nodes)

        # Fail fast, before any tensor is built.
        if demand.size(0) != num_nodes:
            raise ValueError(
                f"The number of rows in the demand matrix ({demand.size(0)}) "
                f"must match the number of unique nodes ({num_nodes}).")

        # Map original identifiers to contiguous indices. ``reshape(-1, 2)``
        # keeps a well-formed [0, 2] tensor when ``od_pairs`` is empty.
        node_mapping = {node: idx for idx, node in enumerate(unique_nodes)}
        pairs = torch.tensor(
            [(node_mapping[o], node_mapping[d]) for o, d in od_pairs],
            dtype=torch.long,
        ).reshape(-1, 2)

        if fully_connected:
            # All unordered node pairs, then mirrored to get both directions.
            upper = torch.combinations(torch.arange(num_nodes), r=2)
            upper = upper.reshape(-1, 2).t()
            edge_index = torch.cat([upper, upper.flip(0)], dim=1)
        else:
            # Exactly the given O-D pairs, as directed edges.
            edge_index = pairs.t().contiguous()

        # Symmetric adjacency built from the O-D pairs.
        adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        src, dst = pairs[:, 0], pairs[:, 1]
        adj[src, dst] = 1.0
        adj[dst, src] = 1.0

        # Self-loops are *set* rather than added, so an (o, o) pair in the
        # input cannot push a diagonal entry to 2.
        diag = torch.arange(num_nodes)
        adj[diag, diag] = 1.0

        if normalize_adj:
            # Row-normalise the off-diagonal part only.
            off_diag = ~torch.eye(num_nodes, dtype=torch.bool)
            degree = (adj * off_diag).sum(dim=1, keepdim=True)
            degree[degree == 0] = 1  # Isolated node: avoid division by zero
            adj = torch.where(off_diag, adj / degree, adj)

        # Node features: mean demand over all time windows. ``mean`` needs a
        # floating-point tensor, so integer counts are converted first.
        values = demand if demand.is_floating_point() else demand.float()
        x = values.mean(dim=1, keepdim=True)

        # Temporal features: the full demand series, unchanged.
        return cls(edge_index=edge_index, x=x, time_series=demand, adj=adj)
1 change: 1 addition & 0 deletions torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .to_fixed_size_transformer import to_fixed_size
from .encoding import PositionalEncoding, TemporalEncoding
from .summary import summary
# from .losses import ZINBLoss

from .aggr import * # noqa
from .conv import * # noqa
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/nn/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# from .zinb_loss import ZINBLoss
70 changes: 70 additions & 0 deletions torch_geometric/nn/losses/zinb_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn


class ZINBLoss(nn.Module):
    """Negative log-likelihood loss of the Zero-Inflated Negative Binomial
    (ZINB) distribution.

    With ``NB(y)`` the Negative Binomial probability parameterised by mean
    ``mu`` and dispersion ``theta``, the ZINB probability is
    ``pi + (1 - pi) * NB(0)`` for ``y == 0`` and ``(1 - pi) * NB(y)``
    otherwise.

    Args:
        eps (float): A small constant to avoid division by zero or log(0).
            Targets below ``eps`` are treated as zero.
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, prediction, target):
        """Computes the negative log-likelihood of the ZINB distribution.

        Args:
            prediction (tuple): A tuple containing (mu, theta, pi) from the
                model:
                - mu: Mean of the Negative Binomial distribution.
                - theta: Dispersion parameter (greater than 0).
                - pi: Zero-inflation probability (between 0 and 1).
            target (torch.Tensor): Ground truth values.

        Returns:
            torch.Tensor: The computed ZINB loss (scalar).
        """
        mu, theta, pi = prediction
        return self.compute_zinb_loss(mu, theta, pi, target)

    def compute_zinb_loss(self, mu, theta, pi, target):
        """Computes the mean ZINB negative log-likelihood.

        Args:
            mu (torch.Tensor): Mean of the Negative Binomial distribution.
            theta (torch.Tensor): Dispersion parameter (greater than 0).
            pi (torch.Tensor): Zero-inflation probability (between 0 and 1).
            target (torch.Tensor): Ground truth values. Integer tensors are
                accepted.

        Returns:
            torch.Tensor: The ZINB loss averaged over all elements.
        """
        # Keep the parameters inside their valid domains.
        mu = torch.clamp(mu, min=self.eps)
        theta = torch.clamp(theta, min=self.eps)
        pi = torch.clamp(pi, min=0, max=1)

        # Clamp pi and (1 - pi) separately: for the default eps the value
        # ``1 - eps`` rounds to exactly 1.0 in float32, so an upper clamp on
        # pi alone would still allow log(1 - pi) = log(0).
        log_pi = torch.log(torch.clamp(pi, min=self.eps))
        log_one_minus_pi = torch.log(torch.clamp(1 - pi, min=self.eps))

        # The zero mask must come from the *unmodified* target; clamping the
        # target first would make the zero branch unreachable.
        is_zero = target < self.eps
        target = torch.clamp(target.to(mu.dtype), min=0)

        # Log-probability of the Negative Binomial (NB) component. At
        # target == 0 this reduces to theta * log(theta / (theta + mu)).
        log_nb = (torch.lgamma(theta + target) - torch.lgamma(target + 1) -
                  torch.lgamma(theta) + theta * torch.log(theta) +
                  target * torch.log(mu) -
                  (theta + target) * torch.log(theta + mu))

        # Zero targets: log(pi + (1 - pi) * NB(0)), evaluated in log-space.
        log_zero_inflated = torch.logaddexp(log_pi, log_one_minus_pi + log_nb)

        # Non-zero targets: log((1 - pi) * NB(y)).
        log_non_zero = log_one_minus_pi + log_nb

        # Both branches are finite everywhere, so selecting with ``where``
        # does not leak NaNs/Infs into the result or its gradient.
        zinb_loss = -torch.where(is_zero, log_zero_inflated, log_non_zero)

        return zinb_loss.mean()  # Return the mean loss across all samples
Loading