Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update PSPNet and ICNet #81

Merged
merged 13 commits into from
Apr 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions ptsemseg/loader/ade20k_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from ptsemseg.utils import recursive_glob

class ADE20KLoader(data.Dataset):
def __init__(self, root, split="training", is_transform=False, img_size=512):
def __init__(self, root, split="training", is_transform=False, img_size=512, augmentations=None, img_norm=True):
self.root = root
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.n_classes = 150
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
Expand All @@ -37,21 +39,25 @@ def __getitem__(self, index):
lbl = m.imread(lbl_path)
lbl = np.array(lbl, dtype=np.int32)

if self.augmentations is not None:
img, lbl = self.augmentations(img, lbl)

if self.is_transform:
img, lbl = self.transform(img, lbl)

return img, lbl


def transform(self, img, lbl):
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

lbl = self.encode_segmap(lbl)
Expand Down
11 changes: 8 additions & 3 deletions ptsemseg/loader/camvid_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

class camvidLoader(data.Dataset):
def __init__(self, root, split="train",
is_transform=False, img_size=None, augmentations=None):
is_transform=False, img_size=None, augmentations=None, img_norm=True):
self.root = root
self.split = split
self.img_size = [360, 480]
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.n_classes = 12
self.files = collections.defaultdict(list)
Expand Down Expand Up @@ -48,10 +49,14 @@ def __getitem__(self, index):
return img, lbl

def transform(self, img, lbl):
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = img.astype(float) / 255.0
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

Expand Down
25 changes: 14 additions & 11 deletions ptsemseg/loader/cityscapes_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class cityscapesLoader(data.Dataset):
[220, 220, 0],
[107, 142, 35],
[152, 251, 152],
[ 0, 130, 180],
[ 0, 130, 180],
[220, 20, 60],
[255, 0, 0],
[ 0, 0, 142],
Expand All @@ -43,8 +43,10 @@ class cityscapesLoader(data.Dataset):

label_colours = dict(zip(range(19), colors))

mean_rgb = {'pascal': [103.939, 116.779, 123.68], 'cityscapes': [73.15835921, 82.90891754, 72.39239876]} # pascal mean for PSPNet and ICNet pre-trained model

def __init__(self, root, split="train", is_transform=False,
img_size=(512, 1024), augmentations=None):
img_size=(512, 1024), augmentations=None, img_norm=True, version='pascal'):
"""__init__

:param root:
Expand All @@ -57,9 +59,10 @@ def __init__(self, root, split="train", is_transform=False,
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.n_classes = 19
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([73.15835921, 82.90891754, 72.39239876])
self.mean = np.array(self.mean_rgb[version])
self.files = {}

self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
Expand Down Expand Up @@ -116,21 +119,21 @@ def transform(self, img, lbl):
:param img:
:param lbl:
"""
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

classes = np.unique(lbl)
lbl = lbl.astype(float)
lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
lbl = lbl.astype(int)


if not np.all(classes == np.unique(lbl)):
print("WARN: resizing labels yielded fewer classes")
Expand Down
20 changes: 13 additions & 7 deletions ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MITSceneParsingBenchmarkLoader(data.Dataset):
https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing

"""
def __init__(self, root, split="training", is_transform=False, img_size=512):
def __init__(self, root, split="training", is_transform=False, img_size=512, augmentations=None, img_norm=True):
"""__init__

:param root:
Expand All @@ -34,6 +34,8 @@ def __init__(self, root, split="training", is_transform=False, img_size=512):
self.root = root
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.n_classes = 151 # 0 is reserved for "other"
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
Expand Down Expand Up @@ -67,6 +69,9 @@ def __getitem__(self, index):
lbl = m.imread(lbl_path)
lbl = np.array(lbl, dtype=np.uint8)

if self.augmentations is not None:
img, lbl = self.augmentations(img, lbl)

if self.is_transform:
img, lbl = self.transform(img, lbl)

Expand All @@ -78,14 +83,15 @@ def transform(self, img, lbl):
:param img:
:param lbl:
"""
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

classes = np.unique(lbl)
Expand Down
16 changes: 9 additions & 7 deletions ptsemseg/loader/nyuv2_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ class NYUv2Loader(data.Dataset):
"""


def __init__(self, root, split="training", is_transform=False, img_size=(480,640), augmentations=None):
def __init__(self, root, split="training", is_transform=False, img_size=(480,640), augmentations=None, img_norm=True):
self.root = root
self.is_transform = is_transform
self.n_classes = 14
self.augmentations = augmentations
self.img_norm = img_norm
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
Expand Down Expand Up @@ -71,14 +72,15 @@ def __getitem__(self, index):


def transform(self, img, lbl):
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

classes = np.unique(lbl)
Expand Down
16 changes: 9 additions & 7 deletions ptsemseg/loader/pascal_voc_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class pascalVOCLoader(data.Dataset):
rather than VOC 2011) - 904 images
"""
def __init__(self, root, split='train_aug', is_transform=False,
img_size=512, augmentations=None):
img_size=512, augmentations=None, img_norm=True):
self.root = os.path.expanduser(root)
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.n_classes = 21
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
Expand Down Expand Up @@ -89,14 +90,15 @@ def __getitem__(self, index):


def transform(self, img, lbl):
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

lbl[lbl==255] = 0
Expand Down
16 changes: 9 additions & 7 deletions ptsemseg/loader/sunrgbd_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ class SUNRGBDLoader(data.Dataset):
test and train labels source: https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz
"""

def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None):
def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None, img_norm=True):
self.root = root
self.is_transform = is_transform
self.n_classes = 38
self.augmentations = augmentations
self.img_norm = img_norm
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
Expand Down Expand Up @@ -78,14 +79,15 @@ def __getitem__(self, index):


def transform(self, img, lbl):
img = img[:, :, ::-1]
img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
img = img[:, :, ::-1] # RGB -> BGR
img = img.astype(np.float64)
img -= self.mean
img = m.imresize(img, (self.img_size[0], self.img_size[1]))
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCWH
if self.img_norm:
# Resize scales images from 0 to 255, thus we need
# to divide by 255.0
img = img.astype(float) / 255.0
# NHWC -> NCHW
img = img.transpose(2, 0, 1)

classes = np.unique(lbl)
Expand Down
28 changes: 27 additions & 1 deletion ptsemseg/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,21 @@

def cross_entropy2d(input, target, weight=None, size_average=True):
n, c, h, w = input.size()
nt, ht, wt = target.size()

# Handle inconsistent size between input and target
if h > ht and w > wt: # upsample labels
target = target.unsequeeze(1)
target = F.upsample(target, size=(h, w), mode='nearest')
target = target.sequeeze(1)
elif h < ht and w < wt: # upsample images
input = F.upsample(input, size=(ht, wt), mode='bilinear')
elif h != ht and w != wt:
raise Exception("Only support upsampling")

log_p = F.log_softmax(input, dim=1)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
log_p = log_p[target.view(-1, 1).repeat(1, c) >= 0]
log_p = log_p.view(-1, c)

mask = target >= 0
Expand Down Expand Up @@ -48,3 +60,17 @@ def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True)
weight=weight,
size_average=size_average)
return loss / float(batch_size)


def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
    """Cross-entropy loss summed over multiple prediction scales.

    Used for auxiliary training: PSPNet weights its two outputs [1.0, 0.4],
    ICNet its three outputs [1.0, 0.4, 0.16].

    :param input: sequence of per-scale score tensors (finest scale first)
    :param target: ground-truth label tensor shared by every scale
    :param weight: optional per-class rescaling weights, forwarded to cross_entropy2d
    :param size_average: forwarded to cross_entropy2d
    :param scale_weight: optional indexable of per-scale loss weights
        (e.g. a torch tensor); when None, defaults to the geometric series
        0.4**i, i.e. [1.0, 0.4, 0.16, ...] for len(input) scales
    :return: weighted sum of the per-scale cross-entropy losses
    """
    # BUG FIX: was `scale_weight == None`. With a tensor argument, `==` does
    # an elementwise comparison, so `if` raises "truth value is ambiguous".
    # Identity comparison with None must use `is None` (PEP 8).
    if scale_weight is None:
        n_inp = len(input)
        scale = 0.4
        # Geometric decay: weight for scale i is 0.4**i.
        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))

    loss = 0.0
    for i, inp in enumerate(input):
        loss = loss + scale_weight[i] * cross_entropy2d(
            input=inp, target=target, weight=weight, size_average=size_average)

    return loss
15 changes: 13 additions & 2 deletions ptsemseg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from ptsemseg.models.segnet import *
from ptsemseg.models.unet import *
from ptsemseg.models.pspnet import *
from ptsemseg.models.icnet import *
from ptsemseg.models.linknet import *
from ptsemseg.models.frrn import *


def get_model(name, n_classes):
def get_model(name, n_classes, version=None):
model = _get_model_instance(name)

if name in ['frrnA', 'frrnB']:
Expand All @@ -30,7 +31,15 @@ def get_model(name, n_classes):
is_batchnorm=True,
in_channels=3,
is_deconv=True)


elif name == 'pspnet':
model = model(n_classes=n_classes, version=version)

elif name == 'icnet':
model = model(n_classes=n_classes, with_bn=False, version=version)
elif name == 'icnetBN':
model = model(n_classes=n_classes, with_bn=True, version=version)

else:
model = model(n_classes=n_classes)

Expand All @@ -45,6 +54,8 @@ def _get_model_instance(name):
'unet': unet,
'segnet': segnet,
'pspnet': pspnet,
'icnet': icnet,
'icnetBN': icnet,
'linknet': linknet,
'frrnA': frrn,
'frrnB': frrn,
Expand Down
Loading