-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
56 lines (44 loc) · 1.31 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import h5py
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
class Hdf5Dataset(Dataset):
    """Map-style dataset serving (image, label) pairs from one HDF5 file.

    Images and labels are two datasets inside the same file, addressed by
    ``x_key`` and ``y_key``; sample ``i`` is ``(transform(x[i]), y[i])``.
    """

    def __init__(self, data_path, x_key, y_key, transform=None):
        """Open ``data_path`` read-only and bind the image and label datasets.

        Args:
            data_path: path of the HDF5 file.
            x_key: name of the dataset holding the images; its first axis
                indexes samples and defines the dataset length.
            y_key: name of the dataset holding the labels.
            transform: optional callable applied to each image. When omitted,
                ``ToTensor()`` followed by ``Normalize`` with mean 0.5 and
                std 0.5 on three channels is used.
        """
        # Kept on the instance so ``close`` can release the file handle.
        # NOTE(review): h5py file objects are generally not picklable, so
        # using this dataset with DataLoader(num_workers>0) may fail under
        # the "spawn" start method — confirm before enabling workers.
        self.data_file = h5py.File(data_path, 'r')
        self.x = self.data_file[x_key]
        self.y = self.data_file[y_key]
        self.N = self.x.shape[0]
        if transform is None:
            # (v - 0.5) / 0.5 maps [0, 1] to [-1, 1]; presumably ToTensor's
            # output lies in [0, 1] and images have 3 channels — confirm.
            transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the image dataset)."""
        return self.N

    def __getitem__(self, index):
        """Return ``(transform(x[index]), y[index])``, label as an int64 tensor."""
        image = self.x[index]
        label = self.y[index]
        # Integer-indexing a 1-D label array yields a NumPy scalar rather
        # than an ndarray; ``torch.from_numpy`` rejects scalars, while
        # ``torch.as_tensor`` accepts both scalars and arrays.
        return self.transform(image), torch.as_tensor(label).long()

    def close(self):
        """Close the underlying HDF5 file handle, if one was opened."""
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()
def get_loader(data_path, x_key, y_key, batch_size, mode='train'):
    """Build a DataLoader over an ``Hdf5Dataset``.

    Args:
        data_path: path of the HDF5 file to read.
        x_key: name of the HDF5 dataset holding the images.
        y_key: name of the HDF5 dataset holding the labels.
        batch_size: number of samples per batch.
        mode: exactly ``'train'`` enables shuffling; any other value keeps
            the file order.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the dataset.
    """
    return DataLoader(
        dataset=Hdf5Dataset(data_path, x_key, y_key),
        batch_size=batch_size,
        shuffle=(mode == 'train'),
    )