-
Notifications
You must be signed in to change notification settings - Fork 128
/
data_loader.py
38 lines (31 loc) · 1.28 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
from PIL import Image
from torch.utils.data import Dataset
class KittiLoader(Dataset):
    """Dataset of KITTI stereo frames read from a single drive directory.

    Left images are listed from ``<root_dir>/image_02/data/`` and, in
    ``'train'`` mode, right images from ``<root_dir>/image_03/data/``. Each
    listing is sorted by path, and left/right frames are paired by position
    in those sorted lists.

    Args:
        root_dir: Directory containing the ``image_02`` (and, for training,
            ``image_03``) sub-directories.
        mode: ``'train'`` yields ``{'left_image': ..., 'right_image': ...}``
            dicts; any other value yields only the left image.
        transform: Optional callable applied to each sample. It receives the
            dict in ``'train'`` mode and the left image otherwise.

    Raises:
        ValueError: In ``'train'`` mode, if the left and right directories
            contain a different number of entries.
    """

    def __init__(self, root_dir, mode, transform=None):
        left_dir = os.path.join(root_dir, 'image_02/data/')
        self.left_paths = sorted(
            os.path.join(left_dir, fname) for fname in os.listdir(left_dir)
        )
        if mode == 'train':
            right_dir = os.path.join(root_dir, 'image_03/data/')
            self.right_paths = sorted(
                os.path.join(right_dir, fname) for fname in os.listdir(right_dir)
            )
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would let mismatched frames be silently paired.
            if len(self.right_paths) != len(self.left_paths):
                raise ValueError(
                    f'left/right image count mismatch: {len(self.left_paths)} '
                    f'in {left_dir!r} vs {len(self.right_paths)} in {right_dir!r}'
                )
        self.transform = transform
        self.mode = mode

    @staticmethod
    def _open_image(path):
        """Decode the image at *path* into memory and release its file handle."""
        # Image.open is lazy: the file stays open until the pixels are decoded.
        # copy() forces the decode, so the handle can be closed here instead of
        # leaking when no transform ever touches the pixel data.
        with Image.open(path) as img:
            return img.copy()

    def __len__(self):
        """Return the number of left images (one sample per left frame)."""
        return len(self.left_paths)

    def __getitem__(self, idx):
        """Return the sample at position *idx*.

        Returns:
            In ``'train'`` mode, the dict ``{'left_image', 'right_image'}``,
            or whatever ``transform`` returns for it. Otherwise, the left image,
            passed through ``transform`` when one is set.
        """
        left_image = self._open_image(self.left_paths[idx])
        if self.mode == 'train':
            right_image = self._open_image(self.right_paths[idx])
            sample = {'left_image': left_image, 'right_image': right_image}
            if self.transform:
                sample = self.transform(sample)
            return sample
        if self.transform:
            left_image = self.transform(left_image)
        return left_image