From 7d5bb55bb1f3999dceae0c4f500b19b184ca5013 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 02:41:15 +0800
Subject: [PATCH 01/13] Add args for mean version and img_norm, and input type
 for m.imresize is uint8 with RGB mode

---
 ptsemseg/loader/cityscapes_loader.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/ptsemseg/loader/cityscapes_loader.py b/ptsemseg/loader/cityscapes_loader.py
index 5686112c..2f60b8af 100644
--- a/ptsemseg/loader/cityscapes_loader.py
+++ b/ptsemseg/loader/cityscapes_loader.py
@@ -31,7 +31,7 @@ class cityscapesLoader(data.Dataset):
               [220, 220,   0],
               [107, 142,  35],
               [152, 251, 152],
-              [  0, 130, 180],
+              [  0, 130, 180],
               [220,  20,  60],
               [255,   0,   0],
               [  0,   0, 142],
@@ -43,8 +43,10 @@ class cityscapesLoader(data.Dataset):

     label_colours = dict(zip(range(19), colors))

+    mean_rgb = {'pascal': [103.939, 116.779, 123.68], 'cityscapes': [73.15835921, 82.90891754, 72.39239876]} # pascal mean for PSPNet and ICNet pre-trained model
+
     def __init__(self, root, split="train", is_transform=False,
-                 img_size=(512, 1024), augmentations=None):
+                 img_size=(512, 1024), augmentations=None, img_norm=True, version='pascal'):
         """__init__

         :param root:
@@ -57,9 +59,10 @@ def __init__(self, root, split="train", is_transform=False,
         self.split = split
         self.is_transform = is_transform
         self.augmentations = augmentations
+        self.img_norm = img_norm
         self.n_classes = 19
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
-        self.mean = np.array([73.15835921, 82.90891754, 72.39239876])
+        self.mean = np.array(self.mean_rgb[version])
         self.files = {}

         self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
@@ -116,21 +119,21 @@ def transform(self, img, lbl):
         :param img:
         :param lbl:
         """
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)
-
+
         classes = np.unique(lbl)
         lbl = lbl.astype(float)
         lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
         lbl = lbl.astype(int)
-
         if not np.all(classes == np.unique(lbl)):
             print("WARN: resizing labels yielded fewer classes")
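
Aside (editor's illustrative sketch, not part of the commit): with this change the
Cityscapes loader picks its channel mean from mean_rgb via the new `version`
argument and only rescales to [0, 1] when `img_norm` is set. The dataset root
below is a hypothetical placeholder:

    from ptsemseg.loader.cityscapes_loader import cityscapesLoader

    loader = cityscapesLoader(root='/data/cityscapes',  # hypothetical path
                              split='val',
                              is_transform=True,
                              img_norm=True,       # divide by 255.0 after mean subtraction
                              version='pascal')    # mean used by the Caffe-pretrained PSPNet/ICNet
    print(loader.mean)  # -> [103.939 116.779 123.68]
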
From d91cbfbd75f737f2fef090d062012fa2fe381b89 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 02:46:00 +0800
Subject: [PATCH 02/13] Add args: img_norm, include_flip_mode, measure_time,
 and fix typo

---
 test.py     | 32 ++++++++++++++++++++++----------
 train.py    | 24 ++++++++++++++++++------
 validate.py | 48 +++++++++++++++++++++++++++++++++++++++---------
 3 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/test.py b/test.py
index e53b10df..4f3dace2 100644
--- a/test.py
+++ b/test.py
@@ -1,10 +1,10 @@
-import sys
+import sys, os
 import torch
 import visdom
 import argparse
 import numpy as np
-import torch.nn as nn
 import scipy.misc as misc
+import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
@@ -30,18 +30,19 @@ def test(args):

     data_loader = get_loader(args.dataset)
     data_path = get_data_path(args.dataset)
-    loader = data_loader(data_path, is_transform=True)
+    loader = data_loader(data_path, is_transform=True, img_norm=args.img_norm)
     n_classes = loader.n_classes

     resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp='bicubic')

+    img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
     img = img[:, :, ::-1]
     img = img.astype(np.float64)
     img -= loader.mean
-    img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
-    img = img.astype(float) / 255.0
-    # NHWC -> NCWH
-    img = img.transpose(2, 0, 1)
+    if args.img_norm:
+        img = img.astype(float) / 255.0
+    # NHWC -> NCHW
+    img = img.transpose(2, 0, 1)
     img = np.expand_dims(img, 0)
     img = torch.from_numpy(img).float()
@@ -56,7 +57,7 @@ def test(args):

     outputs = F.softmax(model(images), dim=1)

-    if args.dcrf == "True":
+    if args.dcrf:
         unary = outputs.data.cpu().numpy()
         unary = np.squeeze(unary, 0)
         unary = -np.log(unary)
@@ -96,8 +97,19 @@
                         help='Path to the saved model')
     parser.add_argument('--dataset', nargs='?', type=str, default='pascal',
                         help='Dataset to use [\'pascal, camvid, ade20k etc\']')
-    parser.add_argument('--dcrf', nargs='?', type=str, default="False",
-                        help='Enable DenseCRF based post-processing')
+
+    parser.add_argument('--img_norm', dest='img_norm', action='store_true',
+                        help='Enable input image scales normalization [0, 1] | True by default')
+    parser.add_argument('--no-img_norm', dest='img_norm', action='store_false',
+                        help='Disable input image scales normalization [0, 1] | True by default')
+    parser.set_defaults(img_norm=True)
+
+    parser.add_argument('--dcrf', dest='dcrf', action='store_true',
+                        help='Enable DenseCRF based post-processing | False by default')
+    parser.add_argument('--no-dcrf', dest='dcrf', action='store_false',
+                        help='Disable DenseCRF based post-processing | False by default')
+    parser.set_defaults(dcrf=False)
+
     parser.add_argument('--img_path', nargs='?', type=str, default=None,
                         help='Path of the input image')
     parser.add_argument('--out_path', nargs='?', type=str, default=None,
diff --git a/train.py b/train.py
index 66561d7a..ae33c670 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-import sys
+import sys, os
 import torch
 import visdom
 import argparse
@@ -26,8 +26,8 @@ def train(args):
     # Setup Dataloader
     data_loader = get_loader(args.dataset)
     data_path = get_data_path(args.dataset)
-    t_loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols), augmentations=data_aug)
-    v_loader = data_loader(data_path, is_transform=True, split='val', img_size=(args.img_rows, args.img_cols))
+    t_loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols), augmentations=data_aug, img_norm=args.img_norm)
+    v_loader = data_loader(data_path, is_transform=True, split='val', img_size=(args.img_rows, args.img_cols), img_norm=args.img_norm)

     n_classes = t_loader.n_classes
     trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8, shuffle=True)
@@ -132,7 +132,14 @@ def train(args):
     parser.add_argument('--img_rows', nargs='?', type=int, default=256,
                         help='Height of the input image')
     parser.add_argument('--img_cols', nargs='?', type=int, default=256,
-                        help='Height of the input image')
+                        help='Width of the input image')
+
+    parser.add_argument('--img_norm', dest='img_norm', action='store_true',
+                        help='Enable input image scales normalization [0, 1] | True by default')
+    parser.add_argument('--no-img_norm', dest='img_norm', action='store_false',
+                        help='Disable input image scales normalization [0, 1] | True by default')
+    parser.set_defaults(img_norm=True)
+
     parser.add_argument('--n_epoch', nargs='?', type=int, default=100,
                         help='# of the epochs')
     parser.add_argument('--batch_size', nargs='?', type=int, default=1,
@@ -143,7 +150,12 @@ def train(args):
                         help='Divider for # of features to use')
     parser.add_argument('--resume', nargs='?', type=str, default=None,
                         help='Path to previous saved model to restart from')
-    parser.add_argument('--visdom', nargs='?', type=bool, default=False,
-                        help='Show visualization(s) on visdom | False by default')
+
+    parser.add_argument('--visdom', dest='visdom', action='store_true',
+                        help='Enable visualization(s) on visdom | False by default')
+    parser.add_argument('--no-visdom', dest='visdom', action='store_false',
+                        help='Disable visualization(s) on visdom | False by default')
+    parser.set_defaults(visdom=False)
+
     args = parser.parse_args()
     train(args)
diff --git a/validate.py b/validate.py
index 46cfa0cb..2a16821a 100644
--- a/validate.py
+++ b/validate.py
@@ -1,8 +1,10 @@
-import sys
+import sys, os
 import torch
 import visdom
 import argparse
+import timeit
 import numpy as np
+import scipy.misc as misc
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.models as models
@@ -23,30 +25,38 @@
 cudnn.benchmark = True

 def validate(args):
+    model_file_name = os.path.split(args.model_path)[1]
+    model_name = model_file_name[:model_file_name.find('_')]

     # Setup Dataloader
     data_loader = get_loader(args.dataset)
     data_path = get_data_path(args.dataset)
-    loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols))
+    loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols), img_norm=args.img_norm)
     n_classes = loader.n_classes
     valloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4)
     running_metrics = runningScore(n_classes)

     # Setup Model
-    model = get_model(args.model_path[:args.model_path.find('_')], n_classes)
+    model = get_model(model_name, n_classes, version=args.dataset)
     state = convert_state_dict(torch.load(args.model_path)['model_state'])
     model.load_state_dict(state)
     model.eval()
+    model.cuda()
+
+    for i, (images, labels) in enumerate(valloader):
+        start_time = timeit.default_timer()

-    for i, (images, labels) in tqdm(enumerate(valloader)):
-        model.cuda()
         images = Variable(images.cuda(), volatile=True)
-        labels = Variable(labels.cuda(), volatile=True)
+        #labels = Variable(labels.cuda(), volatile=True)

         outputs = model(images)
         pred = outputs.data.max(1)[1].cpu().numpy()
-        gt = labels.data.cpu().numpy()
-
+        #gt = labels.data.cpu().numpy()
+        gt = labels.numpy()
+
+        if args.measure_time:
+            elapsed_time = timeit.default_timer() - start_time
+            print('Inference time (iter {0:5d}): {1:3.5f} fps'.format(i+1, pred.shape[0]/elapsed_time))
         running_metrics.update(gt, pred)

     score, class_iou = running_metrics.get_scores()
@@ -66,10 +76,30 @@
     parser.add_argument('--img_rows', nargs='?', type=int, default=256,
                         help='Height of the input image')
     parser.add_argument('--img_cols', nargs='?', type=int, default=256,
-                        help='Height of the input image')
+                        help='Width of the input image')
+
+    parser.add_argument('--img_norm', dest='img_norm', action='store_true',
+                        help='Enable input image scales normalization [0, 1] | True by default')
+    parser.add_argument('--no-img_norm', dest='img_norm', action='store_false',
+                        help='Disable input image scales normalization [0, 1] | True by default')
+    parser.set_defaults(img_norm=True)
+
+    parser.add_argument('--include_flip_mode', dest='include_flip_mode', action='store_true',
+                        help='Enable evaluation with flipped image | True by default')
+    parser.add_argument('--no-include_flip_mode', dest='include_flip_mode', action='store_false',
+                        help='Disable evaluation with flipped image | True by default')
+    parser.set_defaults(include_flip_mode=True)
+
     parser.add_argument('--batch_size', nargs='?', type=int, default=1,
                         help='Batch Size')
     parser.add_argument('--split', nargs='?', type=str, default='val',
                         help='Split of dataset to test on')
+
+    parser.add_argument('--measure_time', dest='measure_time', action='store_true',
+                        help='Enable evaluation with time (fps) measurement | True by default')
+    parser.add_argument('--no-measure_time', dest='measure_time', action='store_false',
+                        help='Disable evaluation with time (fps) measurement | True by default')
+    parser.set_defaults(measure_time=True)
+
     args = parser.parse_args()
     validate(args)
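
Aside (editor's sketch, not part of the commit): the --foo/--no-foo pairs added
above are the stock argparse idiom for boolean flags (type=bool is a trap, since
bool('False') is True). Both options write the same destination, and
set_defaults decides the value when neither flag is passed:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--img_norm', dest='img_norm', action='store_true')
    parser.add_argument('--no-img_norm', dest='img_norm', action='store_false')
    parser.set_defaults(img_norm=True)

    print(parser.parse_args([]).img_norm)                  # True
    print(parser.parse_args(['--no-img_norm']).img_norm)   # False
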
From 30fec6c418d9bff7fca61a0aacba646f7ebbacf2 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 03:03:19 +0800
Subject: [PATCH 03/13] Add pspnet with version arg

---
 ptsemseg/models/__init__.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/ptsemseg/models/__init__.py b/ptsemseg/models/__init__.py
index 2c92179c..f78ddf93 100644
--- a/ptsemseg/models/__init__.py
+++ b/ptsemseg/models/__init__.py
@@ -8,7 +8,7 @@
 from ptsemseg.models.frrn import *


-def get_model(name, n_classes):
+def get_model(name, n_classes, version=None):
     model = _get_model_instance(name)

     if name in ['frrnA', 'frrnB']:
@@ -30,7 +30,10 @@ def get_model(name, n_classes):
                       is_batchnorm=True,
                       in_channels=3,
                       is_deconv=True)
-
+
+    elif name == 'pspnet':
+        model = model(n_classes=n_classes, version=version)
+
     else:
         model = model(n_classes=n_classes)
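
Aside (editor's sketch, not part of the commit): with this dispatch the dataset
name doubles as the spec key for pretrained configurations; a hypothetical call:

    from ptsemseg.models import get_model

    model = get_model('pspnet', n_classes=19, version='cityscapes')
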
From 97db8be28586a392621f158a13742f53e4230d4e Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 03:06:30 +0800
Subject: [PATCH 04/13] Add auxiliary training layers and modify test part

---
 ptsemseg/models/pspnet.py | 47 +++++++++++++++++++++++++++++----------
 1 file changed, 35 insertions(+), 12 deletions(-)

diff --git a/ptsemseg/models/pspnet.py b/ptsemseg/models/pspnet.py
index 3cdc13a0..10f8b575 100644
--- a/ptsemseg/models/pspnet.py
+++ b/ptsemseg/models/pspnet.py
@@ -40,6 +40,7 @@ class pspnet(nn.Module):
     References:
     1) Original Author's code: https://github.com/hszhao/PSPNet
     2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet
+    3) TensorFlow implementation by @hellochick: https://github.com/hellochick/PSPNet-tensorflow

     Visualization:
     http://dgschwend.github.io/netscope/#/gist/6bfb59e6a3cfcb4e2bb8d47f827c2928
@@ -82,6 +83,10 @@ def __init__(self,
         self.dropout = nn.Dropout2d(p=0.1, inplace=True)
         self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0)

+        # Auxiliary layers for training
+        self.convbnrelu4_aux = conv2DBatchNormRelu(in_channels=1024, k_size=3, n_filters=256, padding=1, stride=1, bias=False)
+        self.aux_cls = nn.Conv2d(256, self.n_classes, 1, 1, 0)
+
     def forward(self, x):
         inp_shape = x.shape[2:]

@@ -97,6 +102,12 @@ def forward(self, x):
         x = self.res_block2(x)
         x = self.res_block3(x)
         x = self.res_block4(x)
+
+        # Auxiliary layers for training
+        x_aux = self.convbnrelu4_aux(x)
+        x_aux = self.dropout(x_aux)
+        x_aux = self.aux_cls(x_aux)
+
         x = self.res_block5(x)

         x = self.pyramid_pooling(x)
@@ -105,8 +116,8 @@ def forward(self, x):
         x = self.dropout(x)

         x = self.classification(x)
-        x = F.upsample(x, size=inp_shape, mode='bilinear')
-        return x
+        x = F.upsample(x, size=inp_shape, mode='bilinear')
+        return x_aux, x

     def load_pretrained_model(self, model_path):
         """
@@ -120,7 +131,7 @@ def load_pretrained_model(self, model_path):

         def _get_layer_params(layer, ltype):

-            if ltype == 'BNData':
+            if ltype == 'BNData':
                 gamma = np.array(layer.blobs[0].data)
                 beta = np.array(layer.blobs[1].data)
                 mean = np.array(layer.blobs[2].data)
@@ -178,7 +189,7 @@
                 print("CONV {}: Original {} and trans weights {}".format(layer_name,
                                                                          w_shape,
                                                                          weights.shape))
-
+
             module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight))

             if len(bias) != 0:
@@ -186,7 +197,7 @@
                 print("CONV {}: Original {} and trans bias {}".format(layer_name,
                                                                       b_shape,
                                                                       bias.shape))
-                module.bias.data.copy_(torch.from_numpy(bias))
+                module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias))


         def _transfer_conv_bn(conv_layer_name, mother_module):
@@ -234,7 +245,8 @@ def _transfer_residual(prefix, block):
                                'conv5_3_pool3_conv': self.pyramid_pooling.paths[1].cbr_unit,
                                'conv5_3_pool2_conv': self.pyramid_pooling.paths[2].cbr_unit,
                                'conv5_3_pool1_conv': self.pyramid_pooling.paths[3].cbr_unit,
-                               'conv5_4': self.cbr_final.cbr_unit,}
+                               'conv5_4': self.cbr_final.cbr_unit,
+                               'conv4_' + str(self.block_config[2]+1): self.convbnrelu4_aux.cbr_unit,} # Auxiliary layers for training

         residual_layers = {'conv2': [self.res_block2, self.block_config[0]],
                            'conv3': [self.res_block3, self.block_config[1]],
@@ -247,6 +259,7 @@

         # Transfer weights for final non-bn conv layer
         _transfer_conv('conv6', self.classification)
+        _transfer_conv('conv6_1', self.aux_cls)

         # Transfer weights for all residual layers
         for k, v in residual_layers.items():
@@ -314,11 +327,13 @@ def tile_predict(self, img):
     import matplotlib.pyplot as plt
     import scipy.misc as m
     from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
-    psp = pspnet(version='ade20k')
+    psp = pspnet(version='cityscapes')

     # Just need to do this one time
-    #psp.load_pretrained_model(model_path='/home/meet/models/pspnet101_cityscapes.caffemodel')
-    psp.load_pretrained_model(model_path='/home/meet/models/pspnet50_ADE20K.caffemodel')
+    caffemodel_dir_path = 'PATH_TO_PSPNET_DIR/evaluation/model'
+    psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'pspnet101_cityscapes.caffemodel'))
+    #psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'pspnet50_ADE20K.caffemodel'))
+    #psp.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'pspnet101_VOC2012.caffemodel'))

     # psp.load_state_dict(torch.load('psp.pth'))

@@ -326,8 +341,9 @@ def tile_predict(self, img):
     psp.cuda(cd)
     psp.eval()

-    dst = cl(root='/home/meet/datasets/cityscapes/')
-    img = m.imread('/home/meet/seg/leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png')
+    dataset_root_dir = 'PATH_TO_CITYSCAPES_DIR'
+    dst = cl(root=dataset_root_dir)
+    img = m.imread(os.path.join(dataset_root_dir, 'leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png'))
     m.imsave('cropped.png', img)
     orig_size = img.shape[:-1]
     img = img.transpose(2, 0, 1)
@@ -342,5 +358,12 @@ def tile_predict(self, img):
     # m.imsave('ade20k_sttutgart_tiled.png', decoded)
     m.imsave('ade20k_sttutgart_tiled.png', pred)

-    torch.save(psp.state_dict(), "psp_ade20k.pth")
+    checkpoints_dir_path = 'checkpoints'
+    if not os.path.exists(checkpoints_dir_path):
+        os.mkdir(checkpoints_dir_path)
+    psp = torch.nn.DataParallel(psp, device_ids=range(torch.cuda.device_count())) # append `module.`
+    state = {'model_state': psp.state_dict()}
+    torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_cityscapes.pth"))
+    #torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_50_ade20k.pth"))
+    #torch.save(state, os.path.join(checkpoints_dir_path, "pspnet_101_pascalvoc.pth"))
     print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))
From b451491f3d09d02e18a91573c265722b4fa9fd58 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 03:18:05 +0800
Subject: [PATCH 05/13] Modify tile_predict with flip arg and support for
 batch with tensor type

---
 ptsemseg/models/pspnet.py | 66 ++++++++++++++++++++++-----------------
 1 file changed, 38 insertions(+), 28 deletions(-)

diff --git a/ptsemseg/models/pspnet.py b/ptsemseg/models/pspnet.py
index 10f8b575..88e671bc 100644
--- a/ptsemseg/models/pspnet.py
+++ b/ptsemseg/models/pspnet.py
@@ -266,29 +266,31 @@ def _transfer_residual(prefix, block):
             _transfer_residual(k, v)


-    def tile_predict(self, img):
+    def tile_predict(self, imgs, include_flip_mode=True):
         """
         Predict by takin overlapping tiles from the image.

-        Strides are adaptively computed from the img shape
+        Strides are adaptively computed from the imgs shape
         and input size

-        :param img: np.ndarray with shape [C, H, W] in BGR format
+        :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format
         :param side: int with side length of model input
         :param n_classes: int with number of classes in seg output.
         """

-        side = self.input_size[0]
+        side_x, side_y = self.input_size
         n_classes = self.n_classes
-        h, w = img.shape[1:]
-        n = int(max(h,w) / float(side) + 1)
-        stride_x = ( h - side ) / float(n)
-        stride_y = ( w - side ) / float(n)
+        n_samples, c, h, w = imgs.shape
+        #n = int(max(h,w) / float(side) + 1)
+        n_x = int(h / float(side_x) + 1)
+        n_y = int(w / float(side_y) + 1)
+        stride_x = ( h - side_x ) / float(n_x)
+        stride_y = ( w - side_y ) / float(n_y)

-        x_ends = [[int(i*stride_x), int(i*stride_x) + side] for i in range(n+1)]
-        y_ends = [[int(i*stride_y), int(i*stride_y) + side] for i in range(n+1)]
+        x_ends = [[int(i*stride_x), int(i*stride_x) + side_x] for i in range(n_x+1)]
+        y_ends = [[int(i*stride_y), int(i*stride_y) + side_y] for i in range(n_y+1)]

-        pred = np.zeros([1, n_classes, h, w])
+        pred = np.zeros([n_samples, n_classes, h, w])
         count = np.zeros([h, w])

         slice_count = 0
@@ -296,33 +298,40 @@
         for sy, ey in y_ends:
             slice_count += 1

-            img_slice = img[:, sx:ex, sy:ey]
-            img_slice_flip = np.copy(img_slice[:,:,::-1])
-
+            imgs_slice = imgs[:, :, sx:ex, sy:ey]
+            if include_flip_mode:
+                imgs_slice_flip = torch.from_numpy(np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1])).float()
+
             is_model_on_cuda = next(self.parameters()).is_cuda

-            inp = Variable(torch.unsqueeze(torch.from_numpy(img_slice).float(), 0), volatile=True)
-            flp = Variable(torch.unsqueeze(torch.from_numpy(img_slice_flip).float(), 0), volatile=True)
+            inp = Variable(imgs_slice, volatile=True)
+            if include_flip_mode:
+                flp = Variable(imgs_slice_flip, volatile=True)

             if is_model_on_cuda:
                 inp = inp.cuda()
-                flp = flp.cuda()
-
-            psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
-            psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
-            psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
+                if include_flip_mode:
+                    flp = flp.cuda()
+
+            psub1 = F.softmax(self.forward(inp)[-1], dim=1).data.cpu().numpy()
+            if include_flip_mode:
+                psub2 = F.softmax(self.forward(flp)[-1], dim=1).data.cpu().numpy()
+                psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
+            else:
+                psub = psub1

             pred[:, :, sx:ex, sy:ey] = psub
             count[sx:ex, sy:ey] += 1.0

-    score = (pred / count[None, None, ...]).astype(np.float32)[0]
-    return score / score.sum(axis=0)
+    score = (pred / count[None, None, ...]).astype(np.float32)
+    return score / np.expand_dims(score.sum(axis=1), axis=1)

 # For Testing Purposes only
 if __name__ == '__main__':
     cd = 0
+    import os
     from torch.autograd import Variable
     import matplotlib.pyplot as plt
     import scipy.misc as m
     from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
     img = img.astype(np.float64)
     img -= np.array([123.68, 116.779, 103.939])[:, None, None]
     img = np.copy(img[::-1, :, :])
-    flp = np.copy(img[:, :, ::-1])
+    img = torch.from_numpy(img).float() # convert to torch tensor
+    img = img.unsqueeze(0)

     out = psp.tile_predict(img)
-    pred = np.argmax(out, axis=0)
-    #decoded = dst.decode_segmap(pred)
-    # m.imsave('ade20k_sttutgart_tiled.png', decoded)
-    m.imsave('ade20k_sttutgart_tiled.png', pred)
+    pred = np.argmax(out, axis=1).astype(np.uint8)[0]
+    decoded = dst.decode_segmap(pred)
+    m.imsave('cityscapes_sttutgart_tiled.png', decoded)
+    #m.imsave('cityscapes_sttutgart_tiled.png', pred)

     checkpoints_dir_path = 'checkpoints'
     if not os.path.exists(checkpoints_dir_path):
         os.mkdir(checkpoints_dir_path)
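
Aside (editor's sketch, not part of the commit): tile_predict now averages the
per-tile softmax maps (optionally with their horizontal flips), divides by how
many tiles covered each pixel, and renormalizes over the class axis; the final
division is just the batched form of the old score / score.sum(axis=0):

    import numpy as np

    score = np.random.rand(2, 19, 4, 4).astype(np.float32)  # [N, n_classes, H, W]
    norm = score / np.expand_dims(score.sum(axis=1), axis=1)
    print(norm.sum(axis=1))  # ~1.0 everywhere: a per-pixel class distribution
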
From 393c4b3197ef6b5b03605a818833abdbb08a9811 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 03:33:36 +0800
Subject: [PATCH 06/13] Fix wrong n_blocks for bottleNeckIdentifyPSP in
 residualBlockPSP

---
 ptsemseg/models/utils.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/ptsemseg/models/utils.py b/ptsemseg/models/utils.py
index 862d38b4..6883c010 100644
--- a/ptsemseg/models/utils.py
+++ b/ptsemseg/models/utils.py
@@ -375,17 +375,17 @@ def __init__(self, in_channels, mid_channels, out_channels, stride, dilation=1):
         super(bottleNeckPSP, self).__init__()

-        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
+        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, stride=1, padding=0, bias=False)
         if dilation > 1:
-            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 1,
-                                            padding=dilation, bias=False,
-                                            dilation=dilation)
+            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
+                                            stride=stride, padding=dilation,
+                                            bias=False, dilation=dilation)
         else:
-            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
-                                            stride=stride, padding=1,
+            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
+                                            stride=stride, padding=1,
                                             bias=False, dilation=1)
-        self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, 1, 0, bias=False)
-        self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride, 0, bias=False)
+        self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, stride=1, padding=0, bias=False)
+        self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)

     def forward(self, x):
         conv = self.cb3(self.cbr2(self.cbr1(x)))
@@ -400,14 +400,14 @@ def __init__(self, in_channels, mid_channels, stride, dilation=1):
         self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
         if dilation > 1:
-            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, 1,
-                                            padding=dilation, bias=False,
-                                            dilation=dilation)
+            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
+                                            stride=1, padding=dilation,
+                                            bias=False, dilation=dilation)
         else:
-            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
+            self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
                                             stride=1, padding=1,
                                             bias=False, dilation=1)
-        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, 1, 0, bias=False)
+        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, stride=1, padding=0, bias=False)

     def forward(self, x):
         residual = x
@@ -424,7 +424,7 @@ def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, di
             stride = 1

         layers = [bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)]
-        for i in range(n_blocks):
+        for i in range(n_blocks-1):
             layers.append(bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation))
         self.layers = nn.Sequential(*layers)
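
Aside (editor's sketch, not part of the commit): the range(n_blocks-1) fix makes
residualBlockPSP hold exactly n_blocks bottlenecks in total -- one
strided/projecting bottleNeckPSP plus n_blocks-1 bottleNeckIdentifyPSP blocks --
instead of n_blocks+1 as before, so block_config values such as [3, 4, 6, 3] now
map one-to-one onto ResNet stage depths:

    # Count check, assuming the signature shown in the hunk above.
    from ptsemseg.models.utils import residualBlockPSP

    block = residualBlockPSP(3, 64, 32, 128, 1, 1)  # n_blocks=3
    print(len(block.layers))  # -> 3 (1 conv bottleneck + 2 identity blocks)
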
From 24603dc2df960e439545f8dedd6c1d4ee0c27dc0 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 03:56:33 +0800
Subject: [PATCH 07/13] Add pspnet support for training and testing

---
 ptsemseg/loss.py | 12 ++++++++++++
 test.py          | 40 +++++++++++++++++++++++++---------------
 train.py         | 11 ++++++++++-
 validate.py      | 23 +++++++++++++++++++++--
 4 files changed, 68 insertions(+), 18 deletions(-)

diff --git a/ptsemseg/loss.py b/ptsemseg/loss.py
index 3bce0e1c..dc32c284 100644
--- a/ptsemseg/loss.py
+++ b/ptsemseg/loss.py
@@ -1,11 +1,23 @@
 import torch
 import numpy as np
+import scipy.misc as m

 import torch.nn as nn
 import torch.nn.functional as F

+from torch.autograd import Variable
+

 def cross_entropy2d(input, target, weight=None, size_average=True):
     n, c, h, w = input.size()
+    nt, ct, ht, wt = target.size()
+
+    if h != ht or w != wt: # inconsistent size between input and target
+        lbl = target.data.cpu().numpy()
+        lbl = lbl.astype(float)
+        lbl = m.imresize(lbl, (h, w), 'nearest', mode='F')
+        lbl = lbl.astype(int)
+        target = Variable(torch.from_numpy(lbl).long().cuda())
+
     log_p = F.log_softmax(input, dim=1)
     log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
     log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]

diff --git a/test.py b/test.py
index 4f3dace2..85ca1acd 100644
--- a/test.py
+++ b/test.py
@@ -2,6 +2,7 @@
 import torch
 import visdom
 import argparse
+import timeit
 import numpy as np
 import scipy.misc as misc
 import torch.nn as nn
@@ -23,11 +24,13 @@
     CRF post-processing will not work")

 def test(args):
+    model_file_name = os.path.split(args.model_path)[1]
+    model_name = model_file_name[:model_file_name.find('_')]

     # Setup image
     print("Read Input Image from : {}".format(args.img_path))
     img = misc.imread(args.img_path)
-
+
     data_loader = get_loader(args.dataset)
     data_path = get_data_path(args.dataset)
     loader = data_loader(data_path, is_transform=True, img_norm=args.img_norm)
@@ -35,7 +38,11 @@ def test(args):

     resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp='bicubic')

-    img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
+    orig_size = img.shape[:-1]
+    if model_name == 'pspnet':
+        img = misc.imresize(img, (orig_size[0]//2*2+1, orig_size[1]//2*2+1)) # uint8 with RGB mode, resize width and height which are odd numbers
+    else:
+        img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
     img = img[:, :, ::-1]
     img = img.astype(np.float64)
     img -= loader.mean
@@ -47,16 +54,23 @@ def test(args):
     img = torch.from_numpy(img).float()

     # Setup Model
-    model = get_model(args.model_path[:args.model_path.find('_')], n_classes)
+    model = get_model(model_name, n_classes, version=args.dataset)
     state = convert_state_dict(torch.load(args.model_path)['model_state'])
     model.load_state_dict(state)
     model.eval()
-
-    model.cuda(0)
-    images = Variable(img.cuda(0), volatile=True)

-    outputs = F.softmax(model(images), dim=1)
-
+    if torch.cuda.is_available():
+        model.cuda(0)
+        images = Variable(img.cuda(0), volatile=True)
+    else:
+        images = Variable(img, volatile=True)
+
+    if model_name == 'pspnet':
+        outputs = model(images)[-1]
+    else:
+        outputs = model(images)
+    #outputs = F.softmax(outputs, dim=1)
+
     if args.dcrf:
         unary = outputs.data.cpu().numpy()
         unary = np.squeeze(unary, 0)
@@ -79,13 +93,9 @@ def test(args):
         misc.imsave(dcrf_path, decoded_crf)
         print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

-    if torch.cuda.is_available():
-        model.cuda(0)
-        images = Variable(img.cuda(0), volatile=True)
-    else:
-        images = Variable(img, volatile=True)
-
-    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
+    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0).astype(np.uint8)
+    if model_name == 'pspnet':
+        pred = misc.imresize(pred, orig_size, 'nearest') # uint8 with RGB mode
     decoded = loader.decode_segmap(pred)
     print('Classes found: ', np.unique(pred))
     misc.imsave(args.out_path, decoded)
diff --git a/train.py b/train.py
index ae33c670..fc8f077c 100644
--- a/train.py
+++ b/train.py
@@ -86,7 +86,16 @@ def train(args):
             optimizer.zero_grad()
             outputs = model(images)

-            loss = loss_fn(input=outputs, target=labels)
+            if args.arch == 'pspnet':
+                aux_cls, final_cls = outputs
+
+                aux_loss = loss_fn(input=aux_cls, target=labels)
+                final_loss = loss_fn(input=final_cls, target=labels)
+
+                LAMBDA1, LAMBDA2 = 0.4, 1.0
+                loss = LAMBDA1 * aux_loss + LAMBDA2 * final_loss
+            else:
+                loss = loss_fn(input=outputs, target=labels)

             loss.backward()
             optimizer.step()

diff --git a/validate.py b/validate.py
index 2a16821a..4270cec1 100644
--- a/validate.py
+++ b/validate.py
@@ -49,8 +49,27 @@ def validate(args):
         images = Variable(images.cuda(), volatile=True)
         #labels = Variable(labels.cuda(), volatile=True)

-        outputs = model(images)
-        pred = outputs.data.max(1)[1].cpu().numpy()
+        if model_name == 'pspnet':
+            outputs = model(images)[-1]
+            if args.include_flip_mode:
+                outputs = outputs.data.cpu().numpy()
+                flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
+                flipped_images = Variable(torch.from_numpy( flipped_images ).float().cuda(), volatile=True)
+                outputs_flipped = model( flipped_images )[-1].data.cpu().numpy()
+                outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
+        else:
+            outputs = model(images)
+            if args.include_flip_mode:
+                outputs = outputs.data.cpu().numpy()
+                flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
+                flipped_images = Variable(torch.from_numpy( flipped_images ).float().cuda(), volatile=True)
+                outputs_flipped = model( flipped_images ).data.cpu().numpy()
+                outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
+
+        if args.include_flip_mode:
+            pred = np.argmax(outputs, axis=1).astype(np.uint8)
+        else:
+            pred = outputs.data.max(1)[1].cpu().numpy().astype(np.uint8)
         #gt = labels.data.cpu().numpy()
         gt = labels.numpy()
From d93db6a452ee10be3caaebcca501e3d58cd45ad7 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Thu, 19 Apr 2018 14:12:17 +0800
Subject: [PATCH 08/13] Fix label resize error

---
 ptsemseg/loss.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/ptsemseg/loss.py b/ptsemseg/loss.py
index dc32c284..9942366f 100644
--- a/ptsemseg/loss.py
+++ b/ptsemseg/loss.py
@@ -9,14 +9,16 @@

 def cross_entropy2d(input, target, weight=None, size_average=True):
     n, c, h, w = input.size()
-    nt, ct, ht, wt = target.size()
+    nt, ht, wt = target.size()

     if h != ht or w != wt: # inconsistent size between input and target
         lbl = target.data.cpu().numpy()
         lbl = lbl.astype(float)
-        lbl = m.imresize(lbl, (h, w), 'nearest', mode='F')
-        lbl = lbl.astype(int)
-        target = Variable(torch.from_numpy(lbl).long().cuda())
+        lbl_resized = np.zeros((n, h, w))
+        for i in range(nt):
+            lbl_resized[i,:,:] = m.imresize(lbl[i,:,:], (h, w), 'nearest', mode='F')
+        lbl_resized = lbl_resized.astype(int)
+        target = Variable(torch.from_numpy(lbl_resized).long().cuda())

     log_p = F.log_softmax(input, dim=1)
     log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
From 28c9125269a7ec50a997e5f8dda6de53442085e7 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Fri, 20 Apr 2018 19:44:10 +0800
Subject: [PATCH 09/13] Add arg img_norm, and input type for m.imresize is
 uint8 with RGB mode

---
 ptsemseg/loader/ade20k_loader.py          | 20 ++++++++++++-------
 ptsemseg/loader/camvid_loader.py          | 11 +++++++---
 .../mit_sceneparsing_benchmark_loader.py  | 20 ++++++++++++-------
 ptsemseg/loader/nyuv2_loader.py           | 16 ++++++++-------
 ptsemseg/loader/pascal_voc_loader.py      | 16 ++++++++-------
 ptsemseg/loader/sunrgbd_loader.py         | 16 ++++++++-------
 6 files changed, 61 insertions(+), 38 deletions(-)

diff --git a/ptsemseg/loader/ade20k_loader.py b/ptsemseg/loader/ade20k_loader.py
index 45d51a3a..d43a5ef3 100644
--- a/ptsemseg/loader/ade20k_loader.py
+++ b/ptsemseg/loader/ade20k_loader.py
@@ -11,10 +11,12 @@
 from ptsemseg.utils import recursive_glob

 class ADE20KLoader(data.Dataset):
-    def __init__(self, root, split="training", is_transform=False, img_size=512):
+    def __init__(self, root, split="training", is_transform=False, img_size=512, augmentations=None, img_norm=True):
         self.root = root
         self.split = split
         self.is_transform = is_transform
+        self.augmentations = augmentations
+        self.img_norm = img_norm
         self.n_classes = 150
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
         self.mean = np.array([104.00699, 116.66877, 122.67892])
@@ -37,6 +39,9 @@ def __getitem__(self, index):
         lbl = m.imread(lbl_path)
         lbl = np.array(lbl, dtype=np.int32)

+        if self.augmentations is not None:
+            img, lbl = self.augmentations(img, lbl)
+
         if self.is_transform:
             img, lbl = self.transform(img, lbl)

@@ -44,14 +49,15 @@ def __getitem__(self, index):

     def transform(self, img, lbl):
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

         lbl = self.encode_segmap(lbl)

diff --git a/ptsemseg/loader/camvid_loader.py b/ptsemseg/loader/camvid_loader.py
index fc2df622..76f3398f 100644
--- a/ptsemseg/loader/camvid_loader.py
+++ b/ptsemseg/loader/camvid_loader.py
@@ -11,12 +11,13 @@

 class camvidLoader(data.Dataset):
     def __init__(self, root, split="train",
-                 is_transform=False, img_size=None, augmentations=None):
+                 is_transform=False, img_size=None, augmentations=None, img_norm=True):
         self.root = root
         self.split = split
         self.img_size = [360, 480]
         self.is_transform = is_transform
         self.augmentations = augmentations
+        self.img_norm = img_norm
         self.mean = np.array([104.00699, 116.66877, 122.67892])
         self.n_classes = 12
         self.files = collections.defaultdict(list)
@@ -48,10 +49,14 @@ def __getitem__(self, index):
         return img, lbl

     def transform(self, img, lbl):
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = img.astype(float) / 255.0
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
         # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

diff --git a/ptsemseg/loader/mit_sceneparsing_benchmark_loader.py b/ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
index b89865c6..db936b07 100644
--- a/ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
+++ b/ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
@@ -23,7 +23,7 @@ class MITSceneParsingBenchmarkLoader(data.Dataset):

     https://github.com/CSAILVision/placeschallenge/tree/master/sceneparsing
     """
-    def __init__(self, root, split="training", is_transform=False, img_size=512):
+    def __init__(self, root, split="training", is_transform=False, img_size=512, augmentations=None, img_norm=True):
         """__init__

         :param root:
@@ -34,6 +34,8 @@ def __init__(self, root, split="training", is_transform=False, img_size=512):
         self.root = root
         self.split = split
         self.is_transform = is_transform
+        self.augmentations = augmentations
+        self.img_norm = img_norm
         self.n_classes = 151  # 0 is reserved for "other"
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
         self.mean = np.array([104.00699, 116.66877, 122.67892])
@@ -67,6 +69,9 @@ def __getitem__(self, index):
         lbl = m.imread(lbl_path)
         lbl = np.array(lbl, dtype=np.uint8)

+        if self.augmentations is not None:
+            img, lbl = self.augmentations(img, lbl)
+
         if self.is_transform:
             img, lbl = self.transform(img, lbl)

@@ -78,14 +83,15 @@ def transform(self, img, lbl):
         :param img:
         :param lbl:
         """
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

         classes = np.unique(lbl)
diff --git a/ptsemseg/loader/nyuv2_loader.py b/ptsemseg/loader/nyuv2_loader.py
index 83c9835c..5a2e061b 100644
--- a/ptsemseg/loader/nyuv2_loader.py
+++ b/ptsemseg/loader/nyuv2_loader.py
@@ -25,11 +25,12 @@ class NYUv2Loader(data.Dataset):

     """
-    def __init__(self, root, split="training", is_transform=False, img_size=(480,640), augmentations=None):
+    def __init__(self, root, split="training", is_transform=False, img_size=(480,640), augmentations=None, img_norm=True):
         self.root = root
         self.is_transform = is_transform
         self.n_classes = 14
         self.augmentations = augmentations
+        self.img_norm = img_norm
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
         self.mean = np.array([104.00699, 116.66877, 122.67892])
         self.files = collections.defaultdict(list)
@@ -71,14 +72,15 @@ def __getitem__(self, index):

     def transform(self, img, lbl):
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

         classes = np.unique(lbl)

diff --git a/ptsemseg/loader/pascal_voc_loader.py b/ptsemseg/loader/pascal_voc_loader.py
index cb933229..ae773685 100644
--- a/ptsemseg/loader/pascal_voc_loader.py
+++ b/ptsemseg/loader/pascal_voc_loader.py
@@ -51,11 +51,12 @@ class pascalVOCLoader(data.Dataset):
         rather than VOC 2011) - 904 images
     """
     def __init__(self, root, split='train_aug', is_transform=False,
-                 img_size=512, augmentations=None):
+                 img_size=512, augmentations=None, img_norm=True):
         self.root = os.path.expanduser(root)
         self.split = split
         self.is_transform = is_transform
         self.augmentations = augmentations
+        self.img_norm = img_norm
         self.n_classes = 21
         self.mean = np.array([104.00699, 116.66877, 122.67892])
         self.files = collections.defaultdict(list)
@@ -89,14 +90,15 @@ def __getitem__(self, index):

     def transform(self, img, lbl):
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

         lbl[lbl==255] = 0

diff --git a/ptsemseg/loader/sunrgbd_loader.py b/ptsemseg/loader/sunrgbd_loader.py
index fc0deabf..85fd2f07 100644
--- a/ptsemseg/loader/sunrgbd_loader.py
+++ b/ptsemseg/loader/sunrgbd_loader.py
@@ -25,11 +25,12 @@ class SUNRGBDLoader(data.Dataset):
         test and train labels source:
         https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz
     """
-    def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None):
+    def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None, img_norm=True):
         self.root = root
         self.is_transform = is_transform
         self.n_classes = 38
         self.augmentations = augmentations
+        self.img_norm = img_norm
         self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
         self.mean = np.array([104.00699, 116.66877, 122.67892])
         self.files = collections.defaultdict(list)
@@ -78,14 +79,15 @@ def __getitem__(self, index):

     def transform(self, img, lbl):
-        img = img[:, :, ::-1]
+        img = m.imresize(img, (self.img_size[0], self.img_size[1])) # uint8 with RGB mode
+        img = img[:, :, ::-1] # RGB -> BGR
         img = img.astype(np.float64)
         img -= self.mean
-        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
-        # Resize scales images from 0 to 255, thus we need
-        # to divide by 255.0
-        img = img.astype(float) / 255.0
-        # NHWC -> NCWH
+        if self.img_norm:
+            # Resize scales images from 0 to 255, thus we need
+            # to divide by 255.0
+            img = img.astype(float) / 255.0
+        # NHWC -> NCHW
         img = img.transpose(2, 0, 1)

         classes = np.unique(lbl)
From 8ff983460f73b4baee877eb0eae0b5573d7fdd0e Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Fri, 20 Apr 2018 19:46:38 +0800
Subject: [PATCH 10/13] Fix upsampling for input and target in cross_entropy2d,
 and add loss func for PSPNet and ICNet

---
 ptsemseg/loss.py | 36 ++++++++++++++++++++++------------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/ptsemseg/loss.py b/ptsemseg/loss.py
index 9942366f..5ae10d62 100644
--- a/ptsemseg/loss.py
+++ b/ptsemseg/loss.py
@@ -1,28 +1,26 @@
 import torch
 import numpy as np
-import scipy.misc as m

 import torch.nn as nn
 import torch.nn.functional as F

-from torch.autograd import Variable
-

 def cross_entropy2d(input, target, weight=None, size_average=True):
     n, c, h, w = input.size()
     nt, ht, wt = target.size()

-    if h != ht or w != wt: # inconsistent size between input and target
-        lbl = target.data.cpu().numpy()
-        lbl = lbl.astype(float)
-        lbl_resized = np.zeros((n, h, w))
-        for i in range(nt):
-            lbl_resized[i,:,:] = m.imresize(lbl[i,:,:], (h, w), 'nearest', mode='F')
-        lbl_resized = lbl_resized.astype(int)
-        target = Variable(torch.from_numpy(lbl_resized).long().cuda())
+    # Handle inconsistent size between input and target
+    if h > ht and w > wt: # upsample labels
+        target = target.unsqueeze(1)
+        target = F.upsample(target, size=(h, w), mode='nearest')
+        target = target.squeeze(1)
+    elif h < ht and w < wt: # upsample images
+        input = F.upsample(input, size=(ht, wt), mode='bilinear')
+    elif h != ht and w != wt:
+        raise Exception("Only support upsampling")

     log_p = F.log_softmax(input, dim=1)
     log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
-    log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
+    log_p = log_p[target.view(-1, 1).repeat(1, c) >= 0]
     log_p = log_p.view(-1, c)

     mask = target >= 0
@@ -62,3 +60,17 @@ def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
                                weight=weight,
                                size_average=size_average)
     return loss / float(batch_size)
+
+
+def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
+    # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
+    if scale_weight is None: # scale_weight: torch tensor type
+        n_inp = len(input)
+        scale = 0.4
+        scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp))
+
+    loss = 0.0
+    for i, inp in enumerate(input):
+        loss = loss + scale_weight[i] * cross_entropy2d(input=inp, target=target, weight=weight, size_average=size_average)
+
+    return loss
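
Aside (editor's sketch, not part of the commit): the default scale_weight is the
geometric series 0.4**i over the tuple of model outputs, reproducing the weights
named in the comment -- [1.0, 0.4] for two outputs (PSPNet) and [1.0, 0.4, 0.16]
for three (ICNet). The weights pair with the outputs positionally, so the first
returned map receives the 1.0:

    import torch

    for n_inp in (2, 3):
        scale = 0.4
        print(torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp)))
    # -> [1.0000, 0.4000] and [1.0000, 0.4000, 0.1600]
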
From 8ea13b87b667670df52c1d2274329803a862ad51 Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Fri, 20 Apr 2018 19:48:21 +0800
Subject: [PATCH 11/13] Add pspnet loss and model outputs adjusted by mode

---
 ptsemseg/models/pspnet.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/ptsemseg/models/pspnet.py b/ptsemseg/models/pspnet.py
index 88e671bc..be28ae3a 100644
--- a/ptsemseg/models/pspnet.py
+++ b/ptsemseg/models/pspnet.py
@@ -7,6 +7,7 @@

 from ptsemseg import caffe_pb2
 from ptsemseg.models.utils import *
+from ptsemseg.loss import *

 pspnet_specs = {
     'pascalvoc':
@@ -87,6 +88,9 @@ def __init__(self,
         self.convbnrelu4_aux = conv2DBatchNormRelu(in_channels=1024, k_size=3, n_filters=256, padding=1, stride=1, bias=False)
         self.aux_cls = nn.Conv2d(256, self.n_classes, 1, 1, 0)

+        # Define auxiliary loss function
+        self.loss = multi_scale_cross_entropy2d
+
     def forward(self, x):
         inp_shape = x.shape[2:]

@@ -117,7 +121,11 @@ def forward(self, x):
         x = self.classification(x)
         x = F.upsample(x, size=inp_shape, mode='bilinear')
-        return x_aux, x
+
+        if self.training:
+            return x_aux, x
+        else: # eval mode
+            return x

     def load_pretrained_model(self, model_path):
         """
@@ -313,9 +321,9 @@ def tile_predict(self, imgs, include_flip_mode=True):
                 if include_flip_mode:
                     flp = flp.cuda()

-            psub1 = F.softmax(self.forward(inp)[-1], dim=1).data.cpu().numpy()
+            psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
             if include_flip_mode:
-                psub2 = F.softmax(self.forward(flp)[-1], dim=1).data.cpu().numpy()
+                psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
                 psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
             else:
                 psub = psub1
@@ -363,7 +371,7 @@
     img = img.unsqueeze(0)

     out = psp.tile_predict(img)
-    pred = np.argmax(out, axis=1).astype(np.uint8)[0]
+    pred = np.argmax(out, axis=1)[0]
     decoded = dst.decode_segmap(pred)
     m.imsave('cityscapes_sttutgart_tiled.png', decoded)
     #m.imsave('cityscapes_sttutgart_tiled.png', pred)
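
Aside (editor's sketch, not part of the commit): after this commit the forward
pass returns (x_aux, x) in training mode and only x in eval mode, and self.loss
points at multi_scale_cross_entropy2d, so a training step can look like the
sketch below. Because the auxiliary map comes first in the tuple, the default
scale_weight hands it the 1.0 weight; pass scale_weight explicitly for the usual
1.0-on-final weighting. All tensors here are dummies:

    import torch
    from torch.autograd import Variable
    from ptsemseg.models.pspnet import pspnet

    psp = pspnet(version='cityscapes')
    images = Variable(torch.randn(1, 3, 713, 713))      # dummy input batch
    labels = Variable(torch.zeros(1, 713, 713).long())  # dummy label map

    psp.train()
    outputs = psp(images)             # (x_aux, x): auxiliary and final logits
    loss = psp.loss(outputs, labels)  # multi_scale_cross_entropy2d

    psp.eval()
    final = psp(images)               # final logits only, upsampled to input size
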
From 38745d46135d4e74ac4882fa2630bb843690633a Mon Sep 17 00:00:00 2001
From: adam9500370
Date: Fri, 20 Apr 2018 19:52:29 +0800
Subject: [PATCH 12/13] Make scripts clean

---
 test.py     | 14 ++++++--------
 train.py    | 11 +----------
 validate.py | 39 +++++++++++++++++----------------------
 3 files changed, 24 insertions(+), 40 deletions(-)

diff --git a/test.py b/test.py
index 85ca1acd..10a1032c 100644
--- a/test.py
+++ b/test.py
@@ -39,7 +39,7 @@ def test(args):
     resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp='bicubic')

     orig_size = img.shape[:-1]
-    if model_name == 'pspnet':
+    if model_name in ['pspnet', 'icnet', 'icnetBN']:
         img = misc.imresize(img, (orig_size[0]//2*2+1, orig_size[1]//2*2+1)) # uint8 with RGB mode, resize width and height which are odd numbers
     else:
         img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
@@ -65,10 +65,7 @@ def test(args):
     else:
         images = Variable(img, volatile=True)

-    if model_name == 'pspnet':
-        outputs = model(images)[-1]
-    else:
-        outputs = model(images)
+    outputs = model(images)
     #outputs = F.softmax(outputs, dim=1)

     if args.dcrf:
@@ -93,9 +90,10 @@ def test(args):
         misc.imsave(dcrf_path, decoded_crf)
         print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

-    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0).astype(np.uint8)
-    if model_name == 'pspnet':
-        pred = misc.imresize(pred, orig_size, 'nearest') # uint8 with RGB mode
+    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
+    if model_name in ['pspnet', 'icnet', 'icnetBN']:
+        pred = pred.astype(np.float32)
+        pred = misc.imresize(pred, orig_size, 'nearest', mode='F') # float32 with F mode, resize back to orig_size
     decoded = loader.decode_segmap(pred)
     print('Classes found: ', np.unique(pred))
     misc.imsave(args.out_path, decoded)

diff --git a/train.py b/train.py
index fc8f077c..ae33c670 100644
--- a/train.py
+++ b/train.py
@@ -86,16 +86,7 @@ def train(args):
             optimizer.zero_grad()
             outputs = model(images)

-            if args.arch == 'pspnet':
-                aux_cls, final_cls = outputs
-
-                aux_loss = loss_fn(input=aux_cls, target=labels)
-                final_loss = loss_fn(input=final_cls, target=labels)
-
-                LAMBDA1, LAMBDA2 = 0.4, 1.0
-                loss = LAMBDA1 * aux_loss + LAMBDA2 * final_loss
-            else:
-                loss = loss_fn(input=outputs, target=labels)
+            loss = loss_fn(input=outputs, target=labels)

             loss.backward()
             optimizer.step()

diff --git a/validate.py b/validate.py
index 4270cec1..a415c633 100644
--- a/validate.py
+++ b/validate.py
@@ -49,27 +49,22 @@ def validate(args):
         images = Variable(images.cuda(), volatile=True)
         #labels = Variable(labels.cuda(), volatile=True)

-        if model_name == 'pspnet':
-            outputs = model(images)[-1]
-            if args.include_flip_mode:
-                outputs = outputs.data.cpu().numpy()
-                flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
-                flipped_images = Variable(torch.from_numpy( flipped_images ).float().cuda(), volatile=True)
-                outputs_flipped = model( flipped_images )[-1].data.cpu().numpy()
-                outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
-        else:
+        if args.eval_flip:
             outputs = model(images)
-            if args.include_flip_mode:
-                outputs = outputs.data.cpu().numpy()
-                flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
-                flipped_images = Variable(torch.from_numpy( flipped_images ).float().cuda(), volatile=True)
-                outputs_flipped = model( flipped_images ).data.cpu().numpy()
-                outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
-
-        if args.include_flip_mode:
-            pred = np.argmax(outputs, axis=1).astype(np.uint8)
+
+            # Flip images in numpy (not supported for tensors)
+            outputs = outputs.data.cpu().numpy()
+            flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
+            flipped_images = Variable(torch.from_numpy( flipped_images ).float().cuda(), volatile=True)
+            outputs_flipped = model( flipped_images )
+            outputs_flipped = outputs_flipped.data.cpu().numpy()
+            outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
+
+            pred = np.argmax(outputs, axis=1)
         else:
-            pred = outputs.data.max(1)[1].cpu().numpy().astype(np.uint8)
+            outputs = model(images)
+            pred = outputs.data.max(1)[1].cpu().numpy()

         #gt = labels.data.cpu().numpy()
         gt = labels.numpy()
@@ -103,11 +98,11 @@
                         help='Disable input image scales normalization [0, 1] | True by default')
     parser.set_defaults(img_norm=True)

-    parser.add_argument('--include_flip_mode', dest='include_flip_mode', action='store_true',
+    parser.add_argument('--eval_flip', dest='eval_flip', action='store_true',
                         help='Enable evaluation with flipped image | True by default')
-    parser.add_argument('--no-include_flip_mode', dest='include_flip_mode', action='store_false',
+    parser.add_argument('--no-eval_flip', dest='eval_flip', action='store_false',
                         help='Disable evaluation with flipped image | True by default')
-    parser.set_defaults(include_flip_mode=True)
+    parser.set_defaults(eval_flip=True)

     parser.add_argument('--batch_size', nargs='?', type=int, default=1,
                         help='Batch Size')
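
Aside (editor's note, not from the commit): the Caffe-pretrained PSPNet/ICNet
models expect odd spatial sizes of the form 2n+1 (e.g. 1025x2049 in icnet_specs
below), and orig_size[0]//2*2+1 in test.py rounds an even dimension up to the
nearest odd number while leaving odd ones unchanged:

    for d in (1024, 1025, 713, 500):
        print(d, '->', d // 2 * 2 + 1)
    # 1024 -> 1025, 1025 -> 1025, 713 -> 713, 500 -> 501
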
+        def _transfer_conv(layer_name, module):
+            weights, bias = layer_params[layer_name]
+            w_shape = np.array(module.weight.size())
+
+            print("CONV {}: Original {} and trans weights {}".format(layer_name,
+                                                                     w_shape,
+                                                                     weights.shape))
+
+            module.weight.data.copy_(torch.from_numpy(weights).view_as(module.weight))
+
+            if len(bias) != 0:
+                b_shape = np.array(module.bias.size())
+                print("CONV {}: Original {} and trans bias {}".format(layer_name,
+                                                                      b_shape,
+                                                                      bias.shape))
+                module.bias.data.copy_(torch.from_numpy(bias).view_as(module.bias))
+
+
+        def _transfer_bn(conv_layer_name, bn_module):
+            mean, var, gamma, beta = layer_params[conv_layer_name+'/bn']
+            print("BN {}: Original {} and trans weights {}".format(conv_layer_name,
+                                                                   bn_module.running_mean.size(),
+                                                                   mean.shape))
+            bn_module.running_mean.copy_(torch.from_numpy(mean).view_as(bn_module.running_mean))
+            bn_module.running_var.copy_(torch.from_numpy(var).view_as(bn_module.running_var))
+            bn_module.weight.data.copy_(torch.from_numpy(gamma).view_as(bn_module.weight))
+            bn_module.bias.data.copy_(torch.from_numpy(beta).view_as(bn_module.bias))
+
+
+        def _transfer_conv_bn(conv_layer_name, mother_module):
+            conv_module = mother_module[0]
+            _transfer_conv(conv_layer_name, conv_module)
+
+            if conv_layer_name+'/bn' in layer_params.keys():
+                bn_module = mother_module[1]
+                _transfer_bn(conv_layer_name, bn_module)
+
+
+        def _transfer_residual(block_name, block):
+            block_module, n_layers = block[0], block[1]
+            prefix = block_name[:5]
+
+            if ('bottleneck' in block_name) or ('identity' not in block_name): # Conv block
+                bottleneck = block_module.layers[0]
+                bottleneck_conv_bn_dic = {prefix + '_1_1x1_reduce': bottleneck.cbr1.cbr_unit,
+                                          prefix + '_1_3x3': bottleneck.cbr2.cbr_unit,
+                                          prefix + '_1_1x1_proj': bottleneck.cb4.cb_unit,
+                                          prefix + '_1_1x1_increase': bottleneck.cb3.cb_unit,}
+
+                for k, v in bottleneck_conv_bn_dic.items():
+                    _transfer_conv_bn(k, v)
+
+            if ('identity' in block_name) or ('bottleneck' not in block_name): # Identity blocks
+                base_idx = 2 if 'identity' in block_name else 1
+
+                for layer_idx in range(2, n_layers+1):
+                    residual_layer = block_module.layers[layer_idx-base_idx]
+                    residual_conv_bn_dic = {'_'.join(map(str, [prefix, layer_idx, '1x1_reduce'])): residual_layer.cbr1.cbr_unit,
+                                            '_'.join(map(str, [prefix, layer_idx, '3x3'])): residual_layer.cbr2.cbr_unit,
+                                            '_'.join(map(str, [prefix, layer_idx, '1x1_increase'])): residual_layer.cb3.cb_unit,}
+
+                    for k, v in residual_conv_bn_dic.items():
+                        _transfer_conv_bn(k, v)
+
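+        # Mapping tables: caffe layer names on the left, the corresponding
+        # PyTorch sub-modules of this model on the right; _transfer_conv_bn
+        # and _transfer_residual walk these to copy all weights.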
+        convbn_layer_mapping = {'conv1_1_3x3_s2': self.convbnrelu1_1.cbr_unit,
+                                'conv1_2_3x3': self.convbnrelu1_2.cbr_unit,
+                                'conv1_3_3x3': self.convbnrelu1_3.cbr_unit,
+                                'conv1_sub1': self.convbnrelu1_sub1.cbr_unit,
+                                'conv2_sub1': self.convbnrelu2_sub1.cbr_unit,
+                                'conv3_sub1': self.convbnrelu3_sub1.cbr_unit,
+                                #'conv5_3_pool6_conv': self.pyramid_pooling.paths[0].cbr_unit,
+                                #'conv5_3_pool3_conv': self.pyramid_pooling.paths[1].cbr_unit,
+                                #'conv5_3_pool2_conv': self.pyramid_pooling.paths[2].cbr_unit,
+                                #'conv5_3_pool1_conv': self.pyramid_pooling.paths[3].cbr_unit,
+                                'conv5_4_k1': self.conv5_4_k1.cbr_unit,
+                                'conv_sub4': self.cff_sub24.low_dilated_conv_bn.cb_unit,
+                                'conv3_1_sub2_proj': self.cff_sub24.high_proj_conv_bn.cb_unit,
+                                'conv_sub2': self.cff_sub12.low_dilated_conv_bn.cb_unit,
+                                'conv3_sub1_proj': self.cff_sub12.high_proj_conv_bn.cb_unit,}
+
+        residual_layers = {'conv2': [self.res_block2, self.block_config[0]],
+                           'conv3_bottleneck': [self.res_block3_conv, self.block_config[1]],
+                           'conv3_identity': [self.res_block3_identity, self.block_config[1]],
+                           'conv4': [self.res_block4, self.block_config[2]],
+                           'conv5': [self.res_block5, self.block_config[3]],}
+
+        # Transfer weights for all non-residual conv+bn layers
+        for k, v in convbn_layer_mapping.items():
+            _transfer_conv_bn(k, v)
+
+        # Transfer weights for final non-bn conv layers
+        _transfer_conv('conv6_cls', self.classification)
+        _transfer_conv('conv6_sub4', self.cff_sub24.low_classifier_conv)
+        _transfer_conv('conv6_sub2', self.cff_sub12.low_classifier_conv)
+
+        # Transfer weights for all residual layers
+        for k, v in residual_layers.items():
+            _transfer_residual(k, v)
+
+
+    def tile_predict(self, imgs, include_flip_mode=True):
+        """
+        Predict by taking overlapping tiles from the image.
+
+        Strides are adaptively computed from the imgs shape
+        and input size
+
+        :param imgs: torch.Tensor with shape [N, C, H, W] in BGR format
+        :param include_flip_mode: bool, average the predictions over the
+                                  image and its horizontal flip when True
+        """
+
+        side_x, side_y = self.input_size
+        n_classes = self.n_classes
+        n_samples, c, h, w = imgs.shape
+        #n = int(max(h,w) / float(side) + 1)
+        n_x = int(h / float(side_x) + 1)
+        n_y = int(w / float(side_y) + 1)
+        stride_x = (h - side_x) / float(n_x)
+        stride_y = (w - side_y) / float(n_y)
+
+        x_ends = [[int(i*stride_x), int(i*stride_x) + side_x] for i in range(n_x+1)]
+        y_ends = [[int(i*stride_y), int(i*stride_y) + side_y] for i in range(n_y+1)]
+
+        pred = np.zeros([n_samples, n_classes, h, w])
+        count = np.zeros([h, w])
+
+        slice_count = 0
+        for sx, ex in x_ends:
+            for sy, ey in y_ends:
+                slice_count += 1
+
+                imgs_slice = imgs[:, :, sx:ex, sy:ey]
+                if include_flip_mode:
+                    imgs_slice_flip = torch.from_numpy(np.copy(imgs_slice.cpu().numpy()[:, :, :, ::-1])).float()
+
+                is_model_on_cuda = next(self.parameters()).is_cuda
+
+                inp = Variable(imgs_slice, volatile=True)
+                if include_flip_mode:
+                    flp = Variable(imgs_slice_flip, volatile=True)
+
+                if is_model_on_cuda:
+                    inp = inp.cuda()
+                    if include_flip_mode:
+                        flp = flp.cuda()
+
+                psub1 = F.softmax(self.forward(inp), dim=1).data.cpu().numpy()
+                if include_flip_mode:
+                    psub2 = F.softmax(self.forward(flp), dim=1).data.cpu().numpy()
+                    psub = (psub1 + psub2[:, :, :, ::-1]) / 2.0
+                else:
+                    psub = psub1
+
+                pred[:, :, sx:ex, sy:ey] = psub
+                count[sx:ex, sy:ey] += 1.0
+
+        # Average the overlapping tiles, then renormalize each pixel
+        # back to a probability distribution over classes
+        score = (pred / count[None, None, ...]).astype(np.float32)
+        return score / np.expand_dims(score.sum(axis=1), axis=1)
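+
+# tile_predict expects a mean-subtracted, BGR, float [N, C, H, W] tensor and
+# returns per-pixel class probabilities as a numpy array; a minimal usage
+# sketch (mirroring the demo below, with `img` preprocessed the same way):
+#
+#     probs = ic.tile_predict(img)        # [N, n_classes, H, W]
+#     labels = np.argmax(probs, axis=1)   # [N, H, W]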
+
+
+# For Testing Purposes only
+if __name__ == '__main__':
+    cd = 0
+    import os
+    from torch.autograd import Variable
+    import matplotlib.pyplot as plt
+    import scipy.misc as m
+    from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
+    ic = icnet(version='cityscapes', with_bn=False)
+
+    # Just need to do this one time
+    caffemodel_dir_path = 'PATH_TO_ICNET_DIR/evaluation/model'
+    ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'icnet_cityscapes_train_30k.caffemodel'))
+    #ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'icnet_cityscapes_train_30k_bnnomerge.caffemodel'))
+    #ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'icnet_cityscapes_trainval_90k.caffemodel'))
+    #ic.load_pretrained_model(model_path=os.path.join(caffemodel_dir_path, 'icnet_cityscapes_trainval_90k_bnnomerge.caffemodel'))
+
+    # ic.load_state_dict(torch.load('ic.pth'))
+
+    ic.float()
+    ic.cuda(cd)
+    ic.eval()
+
+    dataset_root_dir = 'PATH_TO_CITYSCAPES_DIR'
+    dst = cl(root=dataset_root_dir)
+    img = m.imread(os.path.join(dataset_root_dir, 'leftImg8bit/demoVideo/stuttgart_00/stuttgart_00_000000_000010_leftImg8bit.png'))
+    m.imsave('test_input.png', img)
+    orig_size = img.shape[:-1]
+    img = m.imresize(img, ic.input_size)  # uint8 with RGB mode
+    img = img.transpose(2, 0, 1)
+    img = img.astype(np.float64)
+    img -= np.array([123.68, 116.779, 103.939])[:, None, None]  # pascal mean (RGB order)
+    img = np.copy(img[::-1, :, :])  # RGB -> BGR
+    img = torch.from_numpy(img).float()
+    img = img.unsqueeze(0)
+
+    out = ic.tile_predict(img)
+    pred = np.argmax(out, axis=1)[0]
+    pred = pred.astype(np.float32)
+    pred = m.imresize(pred, orig_size, 'nearest', mode='F')  # float32 with F mode
+    decoded = dst.decode_segmap(pred)
+    m.imsave('test_output.png', decoded)
+    #m.imsave('test_output.png', pred)
+
+    checkpoints_dir_path = 'checkpoints'
+    if not os.path.exists(checkpoints_dir_path):
+        os.mkdir(checkpoints_dir_path)
+    ic = torch.nn.DataParallel(ic, device_ids=range(torch.cuda.device_count()))
+    state = {'model_state': ic.state_dict()}
+    torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_train_30k.pth"))
+    #torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_train_30k.pth"))
+    #torch.save(state, os.path.join(checkpoints_dir_path, "icnet_cityscapes_trainval_90k.pth"))
+    #torch.save(state, os.path.join(checkpoints_dir_path, "icnetBN_cityscapes_trainval_90k.pth"))
+    print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))
diff --git a/ptsemseg/models/utils.py b/ptsemseg/models/utils.py
index 6883c010..4aa96384 100644
--- a/ptsemseg/models/utils.py
+++ b/ptsemseg/models/utils.py
@@ -2,9 +2,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from torch.autograd import Variable
+
 
 class conv2DBatchNorm(nn.Module):
-    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1):
+    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, with_bn=True):
         super(conv2DBatchNorm, self).__init__()
 
         if dilation > 1:
@@ -16,8 +18,11 @@ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True,
                              padding=padding, stride=stride, bias=bias, dilation=1)
 
-        self.cb_unit = nn.Sequential(conv_mod,
-                                     nn.BatchNorm2d(int(n_filters)),)
+        if with_bn:
+            self.cb_unit = nn.Sequential(conv_mod,
+                                         nn.BatchNorm2d(int(n_filters)),)
+        else:
+            self.cb_unit = nn.Sequential(conv_mod,)
 
     def forward(self, inputs):
         outputs = self.cb_unit(inputs)
@@ -38,7 +43,7 @@ def forward(self, inputs):
 
 
 class conv2DBatchNormRelu(nn.Module):
-    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1):
+    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, with_bn=True):
         super(conv2DBatchNormRelu, self).__init__()
 
         if dilation > 1:
@@ -49,9 +54,13 @@ def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True,
         conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                              padding=padding, stride=stride, bias=bias, dilation=1)
 
-        self.cbr_unit = nn.Sequential(conv_mod,
-                                      nn.BatchNorm2d(int(n_filters)),
-                                      nn.ReLU(inplace=True),)
+        if with_bn:
+            self.cbr_unit = nn.Sequential(conv_mod,
+                                          nn.BatchNorm2d(int(n_filters)),
+                                          nn.ReLU(inplace=True),)
+        else:
+            self.cbr_unit = nn.Sequential(conv_mod,
+                                          nn.ReLU(inplace=True),)
 
     def forward(self, inputs):
         outputs = self.cbr_unit(inputs)
@@ -346,46 +355,78 @@ def forward(self, x):
 
 
 class pyramidPooling(nn.Module):
-    def __init__(self, in_channels, pool_sizes):
+    def __init__(self, in_channels, pool_sizes, model_name='pspnet', fusion_mode='cat', with_bn=True):
         super(pyramidPooling, self).__init__()
 
+        bias = not with_bn
+
         self.paths = []
         for i in range(len(pool_sizes)):
-            self.paths.append(conv2DBatchNormRelu(in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0, bias=False))
+            self.paths.append(conv2DBatchNormRelu(in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0, bias=bias, with_bn=with_bn))
 
         self.path_module_list = nn.ModuleList(self.paths)
         self.pool_sizes = pool_sizes
+        self.model_name = model_name
+        self.fusion_mode = fusion_mode
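+
+    # Note on the hard-coded eval-mode kernels in forward(): for the
+    # pre-trained ICNet input of 1025 x 2049, conv5_3 comes out at 33 x 65,
+    # and each (kernel, stride) pair reproduces caffe's pooled grids of
+    # 6x6, 3x3, 2x2 and 1x1 bins, e.g. (33 - 8) / 5 + 1 = 6 rows for the
+    # first path.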
 
     def forward(self, x):
-        output_slices = [x]
         h, w = x.shape[2:]
 
-        for module, pool_size in zip(self.path_module_list, self.pool_sizes):
-            out = F.avg_pool2d(x, int(h/pool_size), int(h/pool_size), 0)
-            out = module(out)
-            out = F.upsample(out, size=(h,w), mode='bilinear')
-            output_slices.append(out)
-
-        return torch.cat(output_slices, dim=1)
+        if self.training or self.model_name != 'icnet': # general settings or pspnet
+            k_sizes = []
+            strides = []
+            for pool_size in self.pool_sizes:
+                k_sizes.append((int(h/pool_size), int(w/pool_size)))
+                strides.append((int(h/pool_size), int(w/pool_size)))
+        else: # eval mode and icnet: pre-trained for 1025 x 2049
+            k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
+            strides = [(5, 10), (10, 20), (16, 32), (33, 65)]
+
+        if self.fusion_mode == 'cat': # pspnet: concat (including x)
+            output_slices = [x]
+
+            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
+                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
+                #out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
+                if self.model_name != 'icnet':
+                    out = module(out)
+                out = F.upsample(out, size=(h,w), mode='bilinear')
+                output_slices.append(out)
+
+            return torch.cat(output_slices, dim=1)
+        else: # icnet: element-wise sum (including x)
+            pp_sum = x
+
+            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
+                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
+                #out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
+                if self.model_name != 'icnet':
+                    out = module(out)
+                out = F.upsample(out, size=(h,w), mode='bilinear')
+                pp_sum = pp_sum + out
+
+            return pp_sum
 
 
 class bottleNeckPSP(nn.Module):
     def __init__(self, in_channels, mid_channels, out_channels,
-                 stride, dilation=1):
+                 stride, dilation=1, with_bn=True):
         super(bottleNeckPSP, self).__init__()
+
+        bias = not with_bn
 
-        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, stride=1, padding=0, bias=False)
+        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, with_bn=with_bn)
         if dilation > 1:
             self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
                                             stride=stride, padding=dilation,
-                                            bias=False, dilation=dilation)
+                                            bias=bias, dilation=dilation, with_bn=with_bn)
         else:
             self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
                                             stride=stride, padding=1,
-                                            bias=False, dilation=1)
-        self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, stride=1, padding=0, bias=False)
-        self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
+                                            bias=bias, dilation=1, with_bn=with_bn)
+        self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, with_bn=with_bn)
+        self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride=stride, padding=0, bias=bias, with_bn=with_bn)
 
     def forward(self, x):
         conv = self.cb3(self.cbr2(self.cbr1(x)))
@@ -395,19 +436,21 @@ def forward(self, x):
 
 
 class bottleNeckIdentifyPSP(nn.Module):
-    def __init__(self, in_channels, mid_channels, stride, dilation=1):
+    def __init__(self, in_channels, mid_channels, stride, dilation=1, with_bn=True):
         super(bottleNeckIdentifyPSP, self).__init__()
 
-        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
+        bias = not with_bn
+
+        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, with_bn=with_bn)
         if dilation > 1:
             self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
                                             stride=1, padding=dilation,
-                                            bias=False, dilation=dilation)
+                                            bias=bias, dilation=dilation, with_bn=with_bn)
         else:
             self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3,
-                                            stride=1, padding=1,
-                                            bias=False, dilation=1)
-        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, stride=1, padding=0, bias=False)
+                                            stride=1, padding=1,
+                                            bias=bias, dilation=1, with_bn=with_bn)
+        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, with_bn=with_bn)
 
     def forward(self, x):
         residual = x
@@ -417,17 +460,78 @@ def forward(self, x):
 
 
 class residualBlockPSP(nn.Module):
-    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation=1):
+    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation=1, include_range='all', with_bn=True):
         super(residualBlockPSP, self).__init__()
 
         if dilation > 1:
             stride = 1
 
-        layers = [bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)]
-        for i in range(n_blocks-1):
-            layers.append(bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation))
+        # residualBlockPSP = convBlockPSP + identityBlockPSPs
+        layers = []
+        if include_range in ['all', 'conv']:
+            layers.append(bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation, with_bn=with_bn))
+        if include_range in ['all', 'identity']:
+            for i in range(n_blocks-1):
+                layers.append(bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation, with_bn=with_bn))
         self.layers = nn.Sequential(*layers)
 
     def forward(self, x):
         return self.layers(x)
+
+
+class cascadeFeatureFusion(nn.Module):
+    def __init__(self, n_classes, low_in_channels, high_in_channels, out_channels, with_bn=True):
+        super(cascadeFeatureFusion, self).__init__()
+
+        bias = not with_bn
+
+        self.low_dilated_conv_bn = conv2DBatchNorm(low_in_channels, out_channels, 3, stride=1, padding=2, bias=bias, dilation=2, with_bn=with_bn)
+        self.low_classifier_conv = nn.Conv2d(int(low_in_channels), int(n_classes), kernel_size=1, padding=0, stride=1, bias=True, dilation=1) # Train only
+        self.high_proj_conv_bn = conv2DBatchNorm(high_in_channels, out_channels, 1, stride=1, padding=0, bias=bias, with_bn=with_bn)
+
+    def forward(self, x_low, x_high):
+        x_low_upsampled = F.upsample(x_low, size=get_interp_size(x_low, z_factor=2), mode='bilinear')
+
+        low_cls = self.low_classifier_conv(x_low_upsampled)
+
+        low_fm = self.low_dilated_conv_bn(x_low_upsampled)
+        high_fm = self.high_proj_conv_bn(x_high)
+        high_fused_fm = F.relu(low_fm + high_fm, inplace=True)
+
+        return high_fused_fm, low_cls
+
+
+def get_interp_size(input, s_factor=1, z_factor=1): # for caffe
+    ori_h, ori_w = input.shape[2:]
+
+    # shrink (s_factor >= 1)
+    ori_h = (ori_h - 1) / s_factor + 1
+    ori_w = (ori_w - 1) / s_factor + 1
+
+    # zoom (z_factor >= 1)
+    ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
+    ori_w = ori_w + (ori_w - 1) * (z_factor - 1)
+
+    resize_shape = (int(ori_h), int(ori_w))
+    return resize_shape
+
+
+def interp(input, output_size, mode='bilinear'):
+    n, c, ih, iw = input.shape
+    oh, ow = output_size
+
+    # normalize to [-1, 1]
+    h = torch.arange(0, oh) / (oh-1) * 2 - 1
+    w = torch.arange(0, ow) / (ow-1) * 2 - 1
+
+    grid = torch.zeros(oh, ow, 2)
+    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
+    grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
+    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
+    grid = Variable(grid)
+    if input.is_cuda:
+        grid = grid.cuda()
+
+    return F.grid_sample(input, grid, mode=mode)
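+
+
+# A minimal sanity sketch for the two caffe-style resize helpers above,
+# assuming the pre-trained ICNet input of 1025 x 2049.
+if __name__ == '__main__':
+    x = Variable(torch.randn(1, 3, 1025, 2049))
+    # shrink: (1025 - 1) / 2 + 1 = 513 and (2049 - 1) / 2 + 1 = 1025
+    print(get_interp_size(x, s_factor=2))         # (513, 1025)
+    # zoom back: 513 + (513 - 1) * (2 - 1) = 1025
+    y = interp(x, get_interp_size(x, s_factor=2))
+    print(y.shape)                                # (1, 3, 513, 1025)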