densepass.py · 126 lines (107 loc) · 4.96 KB
"""Prepare DensePASS dataset"""
import logging
import os
import random
import numpy as np
import torch
from PIL import Image
from torch.utils import data
from segmentron.data.dataloader.seg_data_base import SegmentationDataset


class DensePASSSegmentation(SegmentationDataset):
    """DensePASS Semantic Segmentation Dataset."""
    NUM_CLASS = 19

    def __init__(self, root='/home/jjiang/datasets/DensePASS/DensePASS', split='val',
                 mode=None, transform=None, **kwargs):
        super(DensePASSSegmentation, self).__init__(root, split, mode, transform, **kwargs)
        assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/DensePASS"
        self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
        self.crop_size = [400, 2048]  # for inference only
        assert (len(self.images) == len(self.mask_paths))
        if len(self.images) == 0:
            raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
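    # Note (added for clarity, not in the upstream file): `self._key` is the
    # standard Cityscapes id -> trainId lookup table (raw ids 0-33, -1 for
    # ignored classes) and `self._mapping` is simply the range -1..33.
    # `_class_to_index` below uses np.digitize so that a raw id v indexes
    # `_key` at position v + 1, e.g. raw id 7 ("road") maps to train id 0 and
    # raw id 26 ("car") maps to train id 13. In this loader the masks are
    # already *_labelTrainIds.png files, so `_mask_transform` only casts them
    # to a LongTensor and does not go through this remapping.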
    def _class_to_index(self, mask):
        values = np.unique(mask)
        for value in values:
            assert (value in self._mapping)
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)
    def _val_sync_transform_resize(self, img, mask):
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask
    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        if self.mode == 'test':
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(self.images[index])
        mask = Image.open(self.mask_paths[index])
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask, resize=True)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform_resize(img, mask)
        else:
            assert self.mode == 'testval'
            img, mask = self._val_sync_transform_resize(img, mask)
        if self.transform is not None:
            img = self.transform(img)
        return img, mask, os.path.basename(self.images[index])
    def _mask_transform(self, mask):
        return torch.LongTensor(np.array(mask).astype('int32'))

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        return 0

    @property
    def classes(self):
        """Category names."""
        return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
                'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
                'truck', 'bus', 'train', 'motorcycle', 'bicycle')
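
# Note (added, not in the upstream file): the helper below assumes a
# Cityscapes-style layout under the dataset root, roughly
#   <root>/leftImg8bit/<split>/<subfolder>/<name>_.png
#   <root>/gtFine/<split>/<subfolder>/<name>_labelTrainIds.png
# i.e. each image ending in '_.png' is paired with a mask whose filename is
# derived by replacing '_.png' with '_labelTrainIds.png'.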
def _get_city_pairs(folder, split='train'):
    def get_path_pairs(img_folder, mask_folder):
        img_paths = []
        mask_paths = []
        for root, _, files in os.walk(img_folder):
            for filename in files:
                if filename.startswith('._'):
                    continue
                if filename.endswith('.png'):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('_.png', '_labelTrainIds.png')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        logging.info('cannot find the mask or image: %s %s', imgpath, maskpath)
        logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    if split in ('train', 'val'):
        img_folder = os.path.join(folder, 'leftImg8bit/' + split)
        mask_folder = os.path.join(folder, 'gtFine/' + split)
        img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
        return img_paths, mask_paths
    else:
        assert split == 'test'
        logging.info('DensePASS has no separate test split; using the val split instead.')
        val_img_folder = os.path.join(folder, 'leftImg8bit/val')
        val_mask_folder = os.path.join(folder, 'gtFine/val')
        img_paths, mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
        return img_paths, mask_paths


if __name__ == '__main__':
    dst = DensePASSSegmentation(split='train', mode='train')
    trainloader = data.DataLoader(dst, batch_size=1)
    # Use a name other than `data` for the batch so the loop does not shadow
    # the imported torch.utils.data module.
    for i, batch in enumerate(trainloader):
        imgs, labels, *args = batch
        break
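
# Sketch (added, not in the upstream file): a typical validation-time setup.
# It assumes torchvision is available and that SegmentationDataset accepts a
# torchvision image transform, as the constructor signature suggests; the
# dataset root below is a placeholder.
#
#     import torchvision.transforms as T
#     transform = T.Compose([
#         T.ToTensor(),
#         T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#     ])
#     val_set = DensePASSSegmentation(root='datasets/DensePASS', split='val',
#                                     mode='val', transform=transform)
#     val_loader = data.DataLoader(val_set, batch_size=1, shuffle=False)
#     for img, mask, name in val_loader:
#         ...  # run the model on each panorama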