-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
100 lines (85 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import json
import os
import pickle
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Amazon, Coauthor, Planetoid
# NOTE(review): presumably a numerical-stability epsilon (e.g. for log/div) —
# it is not used anywhere in this chunk; confirm at call sites.
EPS=1e-15
# Canonical lowercase names of every dataset load_dataset() accepts:
# Planetoid (cora/citeseer/pubmed), Coauthor (cs/physics), Amazon (computers/photo).
ALL_DATASETS = [
"cora",
"citeseer",
"pubmed",
"cs",
"physics",
"computers",
"photo"
]
class Mask(object):
    """Plain container bundling the three boolean node masks of a split.

    Instances of this class are what load_split() un-pickles via torch.load,
    so the attribute names (train_mask / val_mask / test_mask) are part of
    the on-disk format and must not change.
    """

    def __init__(self, train_mask, val_mask, test_mask):
        # Stored verbatim, no validation or copying.
        self.train_mask, self.val_mask, self.test_mask = (
            train_mask,
            val_mask,
            test_mask,
        )
def load_dataset(dataset, transform=None):
if dataset.lower() in ["cora", "citeseer", "pubmed"]:
path = os.path.join(".datasets", "Plantoid")
dataset = Planetoid(path, dataset.lower(), transform=transform)
elif dataset.lower() in ["cs", "physics"]:
path = os.path.join(".datasets", "Coauthor", dataset.lower())
dataset = Coauthor(path, dataset.lower(), transform=transform)
elif dataset.lower() in ["computers", "photo"]:
path = os.path.join(".datasets", "Amazon", dataset.lower())
dataset = Amazon(path, dataset.lower(), transform=transform)
else:
print("Dataset not supported!")
assert False
return dataset
def generate_split(dataset, seed=0, train_num_per_c=20, val_num_per_c=30):
    """Generate a fixed-count per-class train/val/test node split.

    Picks `train_num_per_c` training and `val_num_per_c` validation nodes
    uniformly at random from each class; every remaining node goes to test.
    Classes too small to supply both quotas contribute all their nodes to
    the test set.

    Parameters
    ----------
    dataset : str
        Dataset name accepted by load_dataset().
    seed : int
        Seed for torch's RNG, making the split reproducible.
    train_num_per_c, val_num_per_c : int
        Number of train / validation nodes drawn per class.

    Returns
    -------
    (train_mask, val_mask, test_mask) : tuple of 1-D bool tensors
        Node masks over the dataset's nodes.
    """
    torch.manual_seed(seed)
    dataset = load_dataset(dataset)
    data = dataset[0]
    num_classes = dataset.num_classes
    num_nodes = data.y.size(0)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    for c in range(num_classes):
        # .flatten() for a 1-D index tensor — consistent with
        # generate_percent_split (nonzero() alone returns shape (n, 1)).
        all_c_idx = (data.y == c).nonzero().flatten()
        if all_c_idx.size(0) <= train_num_per_c + val_num_per_c:
            # Class too small for both quotas: send it entirely to test
            # (test_mask is inverted below, so mark nothing as taken).
            test_mask[all_c_idx] = True
            continue
        perm = torch.randperm(all_c_idx.size(0))
        c_train_idx = all_c_idx[perm[:train_num_per_c]]
        train_mask[c_train_idx] = True
        test_mask[c_train_idx] = True  # mark as taken; inverted below
        c_val_idx = all_c_idx[perm[train_num_per_c : train_num_per_c + val_num_per_c]]
        val_mask[c_val_idx] = True
        test_mask[c_val_idx] = True  # mark as taken; inverted below
    # Everything not claimed by train/val (and not in an undersized class)
    # becomes the test set.
    test_mask = ~test_mask
    return train_mask, val_mask, test_mask
def generate_percent_split(dataset, seed, train_percent=70, val_percent=20):
    """Generate a percentage-based per-class train/val/test node split.

    From each class, `train_percent`% of the nodes (integer floor) go to
    train and the next `val_percent`% to validation; all remaining nodes
    form the test set.

    Parameters
    ----------
    dataset : str
        Dataset name accepted by load_dataset().
    seed : int
        Seed for torch's RNG, making the split reproducible.
    train_percent, val_percent : int
        Per-class percentages for the train and validation sets.

    Returns
    -------
    (train_mask, val_mask, test_mask) : tuple of 1-D bool tensors
        Node masks over the dataset's nodes.
    """
    torch.manual_seed(seed)
    dataset = load_dataset(dataset)
    data = dataset[0]
    labels = data.y
    n = labels.size(0)
    train_mask, val_mask, test_mask = (
        torch.zeros(n, dtype=torch.bool) for _ in range(3)
    )
    for cls in range(dataset.num_classes):
        members = (labels == cls).nonzero().flatten()
        cls_size = members.size(0)
        n_train = cls_size * train_percent // 100
        n_val = cls_size * val_percent // 100
        order = torch.randperm(cls_size)
        chosen_train = members[order[:n_train]]
        chosen_val = members[order[n_train : n_train + n_val]]
        train_mask[chosen_train] = True
        val_mask[chosen_val] = True
        # Mark train+val as taken; the complement becomes the test set.
        test_mask[chosen_train] = True
        test_mask[chosen_val] = True
    return train_mask, val_mask, ~test_mask
def load_split(path):
    """Load a previously saved Mask object and unpack its three masks.

    Parameters
    ----------
    path : str or path-like
        File written by torch.save() containing a Mask instance.

    Returns
    -------
    (train_mask, val_mask, test_mask)
        The three masks stored on the loaded object.

    NOTE(review): torch.load un-pickles an arbitrary object here — only use
    on trusted files; newer torch versions may require weights_only=False
    for non-tensor payloads — confirm against the torch version in use.
    """
    saved = torch.load(path)
    return saved.train_mask, saved.val_mask, saved.test_mask