-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcityscapes.py
125 lines (105 loc) · 4.99 KB
/
cityscapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Prepare Cityscapes dataset"""
import logging
import os
import random
import numpy as np
import torch
from PIL import Image
from .seg_data_base import SegmentationDataset
class CitySegmentation(SegmentationDataset):
    """Cityscapes semantic segmentation dataset.

    Pairs every ``*leftImg8bit*.png`` image found under ``root`` with its
    ``*gtFine_labelIds*.png`` mask (see ``_get_city_pairs``) and remaps the raw
    label ids stored in the mask to the 19 contiguous class indices listed in
    :attr:`classes`.  Every other raw id, as well as the value 255, becomes -1.

    Args:
        root: Dataset directory holding ``leftImg8bit/`` and ``gtFine/``.
        split: 'train', 'val' or 'test' ('test' is served from the val folders).
        mode: 'train', 'val', 'testval' or 'test'; selects the branch taken in
            ``__getitem__``.
        transform: Optional callable applied to the image (never to the mask)
            right before it is returned.
        **kwargs: Forwarded unchanged to ``SegmentationDataset.__init__``.

    Raises:
        AssertionError: If ``self.root`` does not exist or the image and mask
            lists differ in length.
        RuntimeError: If no image/mask pair was found.

    NOTE(review): ``self.root``, ``self.split``, ``self.mode``,
    ``self.transform``, ``self.crop_size``, ``_sync_transform`` and
    ``_img_transform`` are expected to come from the base class - confirm.
    """

    NUM_CLASS = 19

    def __init__(self, root='/home/jjiang/datasets/Cityscapes', split='train', mode=None, transform=None, **kwargs):
        super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs)
        assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/cityscapes"
        self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
        assert len(self.images) == len(self.mask_paths)
        if not self.images:
            raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
        # Lookup table from raw label id to class index.  Entry ``i`` belongs to
        # raw id ``i - 1`` (the first entry is for the -1 sentinel); -1 marks
        # ids that are not one of the 19 classes.
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        # Sorted list of every raw id the table covers: -1, 0, ..., len(_key) - 2.
        self._mapping = np.arange(-1, len(self._key) - 1, dtype='int32')

    def _class_to_index(self, mask):
        """Translate an array of raw label ids into class indices.

        Args:
            mask: Integer array of raw label ids; 255 is treated like -1.

        Returns:
            Array of the same shape holding class indices in ``[-1, 18]``.

        Raises:
            AssertionError: If the mask contains an id outside ``self._mapping``.
        """
        # Work on an int32 copy: the caller's array stays untouched and the -1
        # sentinel is representable even when an unsigned array is passed in.
        mask = np.array(mask, dtype='int32')
        mask[mask == 255] = -1
        unknown = np.setdiff1d(mask, self._mapping)
        assert unknown.size == 0, "unexpected label ids in mask: {}".format(unknown.tolist())
        # With right=True every id ``v`` in the sorted mapping lands on ``v + 1``,
        # which is exactly its row in the lookup table.
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)

    def _val_sync_transform_resize(self, img, mask):
        """Cut the same random ``crop_size`` window out of image and mask.

        Despite the name nothing is resized.  ``self.crop_size`` is read as
        ``(height, width)``.

        Returns:
            ``(self._img_transform(crop), self._mask_transform(crop))``.
        """
        w, h = img.size
        crop_h, crop_w = self.crop_size[0], self.crop_size[1]
        # Clamp the offset range so an image smaller than the crop no longer
        # fails inside random.randint with an empty range.
        # NOTE(review): this relies on Pillow zero-filling the out-of-bounds
        # part of the crop; a zero in the mask is raw id 0, which maps to -1.
        x1 = random.randint(0, max(0, w - crop_w))
        y1 = random.randint(0, max(0, h - crop_h))
        box = (x1, y1, x1 + crop_w, y1 + crop_h)
        img = img.crop(box)
        mask = mask.crop(box)
        return self._img_transform(img), self._mask_transform(mask)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(img, file_name)`` in 'test' mode, otherwise
            ``(img, mask, file_name)`` where ``file_name`` is the base name of
            the image file.

        Raises:
            AssertionError: If ``self.mode`` is none of 'test', 'train', 'val'
                or 'testval'.
        """
        img_path = self.images[index]
        file_name = os.path.basename(img_path)
        img = Image.open(img_path).convert('RGB')
        if self.mode == 'test':
            # Test mode never touches the mask.
            if self.transform is not None:
                img = self.transform(img)
            return img, file_name
        mask = Image.open(self.mask_paths[index])
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask, resize=True)
        else:
            # 'val' and 'testval' share the same preprocessing.
            assert self.mode in ('val', 'testval')
            img, mask = self._val_sync_transform_resize(img, mask)
        if self.transform is not None:
            img = self.transform(img)
        return img, mask, file_name

    def _mask_transform(self, mask):
        """Convert a mask image into an int64 tensor of class indices."""
        target = self._class_to_index(np.array(mask).astype('int32'))
        # Convert straight to int64 instead of casting to int32 first and
        # letting the legacy LongTensor constructor widen it again.
        return torch.as_tensor(target, dtype=torch.long)

    def __len__(self):
        """Number of image/mask pairs found for the split."""
        return len(self.images)

    @property
    def pred_offset(self):
        """Constant 0.

        NOTE(review): presumably the offset between predicted indices and
        label ids - confirm against the callers.
        """
        return 0

    @property
    def classes(self):
        """Category names, ordered by class index."""
        return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
                'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
                'truck', 'bus', 'train', 'motorcycle', 'bicycle')
def _get_city_pairs(folder, split='train'):
    """Collect matching image and mask paths for one Cityscapes split.

    Images are searched recursively under ``<folder>/leftImg8bit/<split>``.
    The mask of an image is expected under ``<folder>/gtFine/<split>`` in a
    sub-folder with the same name as the image's folder, with
    ``leftImg8bit`` in the file name replaced by ``gtFine_labelIds``.

    Args:
        folder: Dataset root directory.
        split: 'train', 'val' or 'test'.  'test' is served from the val
            folders.

    Returns:
        ``(img_paths, mask_paths)``: two parallel lists, ordered by folder
        name and then file name.  Images without a mask are left out.

    Raises:
        AssertionError: If ``split`` is not one of the values above.
    """
    def get_path_pairs(img_folder, mask_folder):
        # Walk the image tree and keep every .png that has a mask next to it.
        img_paths = []
        mask_paths = []
        for root, dirs, files in os.walk(img_folder):
            # os.walk yields entries in arbitrary, filesystem-dependent order;
            # sorting in place keeps sample indices stable across machines.
            dirs.sort()
            for filename in sorted(files):
                # '._*' files are skipped; anything that is not a .png is too.
                if filename.startswith('._') or not filename.endswith('.png'):
                    continue
                imgpath = os.path.join(root, filename)
                foldername = os.path.basename(os.path.dirname(imgpath))
                maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                maskpath = os.path.join(mask_folder, foldername, maskname)
                if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                    img_paths.append(imgpath)
                    mask_paths.append(maskpath)
                else:
                    # The paths must go through %s placeholders: passed as bare
                    # extra arguments the record cannot be formatted.
                    logging.info('cannot find the mask or image: %s %s', imgpath, maskpath)
        logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    if split not in ('train', 'val'):
        assert split == 'test'
        logging.info('test set, but only val set')
        split = 'val'
    img_folder = os.path.join(folder, 'leftImg8bit/' + split)
    mask_folder = os.path.join(folder, 'gtFine/' + split)
    return get_path_pairs(img_folder, mask_folder)
if __name__ == "__main__":
    # Manual smoke test: build the dataset with its default arguments when the
    # module is run as a script.
    dataset = CitySegmentation()