forked from adinba/Modelling_Seminar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Data_toy_classification.py
52 lines (40 loc) · 1.48 KB
/
Data_toy_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch_geometric.data import Data, InMemoryDataset
class CustomGraphDataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset persisted as a single ``data.pt`` file.

    Two construction modes are supported:

    * ``data_list`` given: the graphs are collated in memory. If the processed
      file does not exist yet, :meth:`process` also writes it to disk.
    * ``data_list`` is ``None``: the collated graphs are loaded from
      ``processed_paths[0]``, which must already exist.

    Args:
        root: Dataset root directory, forwarded to the PyG base class (which
            derives ``processed_paths`` from it).
        data_list: Optional list of PyG ``Data`` graphs to wrap.
        transform: Forwarded to the PyG base class.
        pre_transform: Forwarded to the PyG base class.
            NOTE(review): :meth:`process` does not apply it to ``data_list``;
            confirm whether that is intended.
    """

    def __init__(self, root, data_list=None, transform=None, pre_transform=None):
        # Must be assigned before the base constructor runs, because the base
        # class may invoke ``process()``, which reads ``self.data_list``.
        self.data_list = data_list
        super().__init__(root, transform, pre_transform)
        if data_list is not None:
            # Collate here too: the base class only runs ``process()`` when
            # the processed file is missing, so it may have been skipped.
            self.data, self.slices = self.collate(data_list)
        else:
            self.data, self.slices = self._torch_load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """No raw files: graphs are supplied directly or loaded pre-processed."""
        return []

    @property
    def processed_file_names(self):
        """Single processed file holding the collated ``(data, slices)`` tuple."""
        return ['data.pt']

    def download(self):
        """Nothing to download; data comes from ``data_list`` or disk."""
        pass

    def process(self):
        """Collate ``data_list`` and write it to ``processed_paths[0]``.

        Does nothing when no ``data_list`` was supplied; in that case the
        processed file has to exist already.
        """
        if self.data_list is None:
            return
        self.data, self.slices = self.collate(self.data_list)
        self.save()

    def len(self):
        """Return the number of graphs in the dataset."""
        # The collated object is a plain ``Data`` (it has no ``num_graphs``),
        # and ``slices`` may be ``None`` for a single graph. The base class
        # derives the count from ``slices`` and handles both cases.
        return super().len()

    def get(self, idx):
        """Return the ``idx``-th graph as a standalone ``Data`` object."""
        # Delegate to the PyG base class. It slices every attribute along its
        # concatenation dimension and undoes the index increments applied by
        # ``collate`` (e.g. for ``edge_index``), which a plain slice does not.
        return super().get(idx)

    def save(self):
        """Write the collated ``(data, slices)`` tuple to ``processed_paths[0]``."""
        torch.save((self.data, self.slices), self.processed_paths[0])

    @classmethod
    def load(cls, root):
        """Open a dataset previously processed under ``root``.

        Equivalent to ``cls(root=root)``; the processed file must exist.
        """
        return cls(root=root)

    @staticmethod
    def _torch_load(path):
        """Load the pickled ``(data, slices)`` tuple written by :meth:`save`.

        SECURITY: this unpickles arbitrary objects (``weights_only=False``),
        so only load files from a trusted source.
        """
        import inspect  # local import keeps the file's import block unchanged

        kwargs = {}
        # torch >= 2.6 defaults to ``weights_only=True``, which rejects the
        # PyG ``Data`` objects stored in the file; older torch versions do not
        # accept the keyword at all.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)
# Usage: reopen the toy dataset stored under './data_set_toy' and print
# every graph it contains.
toy_root = './data_set_toy'
loaded_dataset = CustomGraphDataset.load(toy_root)
for graph in loaded_dataset:
    print(graph)