-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate_models_utils.py
105 lines (84 loc) · 6.1 KB
/
evaluate_models_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.metrics import get_link_classification_metrics
from utils.utils import NegativeEdgeSampler, NeighborSampler
from utils.DataLoader import Data
def evaluate_model_link_classification(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
                                       evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate_data: Data, loss_func: nn.Module,
                                       num_neighbors: int = 20, time_gap: int = 2000) -> tuple:
    """
    evaluate models on the link classification task
    :param model_name: str, name of the model; selects which keyword arguments are passed to the backbone
    :param model: nn.Module, the model to be evaluated, model[0] is the embedding backbone and model[1] is the prediction head
    :param neighbor_sampler: NeighborSampler, neighbor sampler, installed on model[0] before evaluation
    :param evaluate_idx_data_loader: DataLoader, evaluate index data loader, yields batches of indices into evaluate_data
    :param evaluate_neg_edge_sampler: NegativeEdgeSampler, evaluate negative edge sampler, must have a fixed seed;
        its random state is reset here, but no negative edges are sampled by this function
    :param evaluate_data: Data, data to be evaluated
    :param loss_func: nn.Module, loss function, called as loss_func(input=predicts, target=labels)
    :param num_neighbors: int, number of neighbors to sample for each node (only forwarded for TGAT and TGN)
    :param time_gap: int, time gap for neighbors to compute node features (not used by this function, kept for interface compatibility)
    :return: tuple (evaluate_losses, evaluate_metrics), evaluate_losses is a list with one float loss per batch,
        evaluate_metrics is a list holding one entry computed by get_link_classification_metrics over all batches;
        both lists are empty when the data loader yields no batch
    :raises ValueError: if model_name is not one of the supported models
    """
    # Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
    assert evaluate_neg_edge_sampler.seed is not None

    # models whose backbone takes the edge ids but no neighbor count
    edge_id_only_models = ('DyGKT', 'QIKT', 'IEKT', 'IPKT', 'DIMKT', 'DKT', 'AKT', 'CTNCM', 'simpleKT')
    supported_models = edge_id_only_models + ('TGAT', 'TGN', 'DyGFormer')

    # fail fast, before the sampler and the model are touched, instead of raising inside the batch loop
    if model_name not in supported_models:
        raise ValueError(f"Wrong value for model_name {model_name}!")

    # loop-invariant dispatch: decide once which optional keyword arguments this backbone receives
    pass_edge_ids = model_name in edge_id_only_models or model_name == 'TGN'
    pass_num_neighbors = model_name in ('TGAT', 'TGN')
    pass_edges_are_positive = model_name == 'TGN'

    evaluate_neg_edge_sampler.reset_random_state()

    model[0].set_neighbor_sampler(neighbor_sampler)

    model.eval()

    with torch.no_grad():
        # store evaluate losses and metrics
        evaluate_losses, evaluate_metrics = [], []
        # per-batch predictions and labels, concatenated after the loop to compute metrics over the full set
        evaluate_predicts, evaluate_labels = [], []
        evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            evaluate_data_indices = evaluate_data_indices.numpy()
            batch_src_node_ids = evaluate_data.src_node_ids[evaluate_data_indices]
            batch_dst_node_ids = evaluate_data.dst_node_ids[evaluate_data_indices]
            batch_node_interact_times = evaluate_data.node_interact_times[evaluate_data_indices]
            batch_edge_ids = evaluate_data.edge_ids[evaluate_data_indices]
            batch_edge_labels = evaluate_data.labels[evaluate_data_indices]

            # arguments shared by every supported backbone
            embedding_kwargs = {
                'src_node_ids': batch_src_node_ids,
                'dst_node_ids': batch_dst_node_ids,
                'node_interact_times': batch_node_interact_times,
            }
            if pass_edge_ids:
                embedding_kwargs['edge_ids'] = batch_edge_ids
            if pass_edges_are_positive:
                embedding_kwargs['edges_are_positive'] = True
            if pass_num_neighbors:
                embedding_kwargs['num_neighbors'] = num_neighbors

            # get temporal embedding of source and destination nodes
            # two Tensors, with shape (batch_size, node_feat_dim)
            batch_src_node_embeddings, batch_dst_node_embeddings = \
                model[0].compute_src_dst_node_temporal_embeddings(**embedding_kwargs)

            # get predicted probabilities, shape (batch_size, )
            predicts = model[1](batch_src_node_embeddings, batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
            labels = torch.tensor(batch_edge_labels, dtype=torch.float32, device=predicts.device)

            loss = loss_func(input=predicts, target=labels)
            # read the scalar once and reuse it for both the loss list and the progress bar
            batch_loss = loss.item()

            evaluate_losses.append(batch_loss)
            evaluate_predicts.append(predicts)
            evaluate_labels.append(labels)

            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {batch_loss}')

        # torch.cat rejects an empty list, so only compute metrics when at least one batch was evaluated
        if evaluate_predicts:
            evaluate_predict = torch.cat(evaluate_predicts, dim=0)
            evaluate_label = torch.cat(evaluate_labels, dim=0)
            evaluate_metrics.append(get_link_classification_metrics(predicts=evaluate_predict, labels=evaluate_label))

    return evaluate_losses, evaluate_metrics