From 0b759c3f8751e65f04a3f61200e26275a7f52ba4 Mon Sep 17 00:00:00 2001
From: David Hafner
Date: Mon, 24 Jun 2019 16:26:28 +0200
Subject: [PATCH] initial commit

---
 README.md           |  29 +++++++
 environment.yml     |  90 +++++++++++++++++++++
 input/.placeholder  |   0
 monodepth_net.py    | 186 ++++++++++++++++++++++++++++++++++++++++++++
 output/.placeholder |   0
 run.py              |  71 +++++++++++++++++
 utils.py            | 181 ++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 557 insertions(+)
 create mode 100644 README.md
 create mode 100644 environment.yml
 create mode 100644 input/.placeholder
 create mode 100644 monodepth_net.py
 create mode 100644 output/.placeholder
 create mode 100644 run.py
 create mode 100644 utils.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..8c466c0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,29 @@
+# Mixing For MonoDepth
+
+This code computes a depth map from a single input image. It runs a neural network that was trained by mixing several datasets, as described in
+
+>Mixing Datasets for Single-Image Depth Estimation in Diverse Environments.
+Rene Ranftl, Katrin Lasinger, Vladlen Koltun
+
+## Setup
+
+1) Download the model weights [model.pt](https://drive.google.com/open?id=1Q9q7dVFhXiNOS1djOlaUUmnJlKMenEoU) and put the file in the same folder as this README.
+
+2) Create and activate the conda environment:
+
+    ```shell
+    conda env create -f environment.yml
+    conda activate mixingDatasetsForMonoDepth
+    ```
+
+## Usage
+
+1) Put one or more input images for monocular depth estimation in the folder `input`.
+
+2) Produce depth maps for the images in the `input` folder as follows:
+
+    ```shell
+    python run.py
+    ```
+
+3) The resulting depth maps are written to the `output` folder.
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..bba3e3a
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,90 @@
+name: mixingDatasetsForMonoDepth
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - blas=1.0=mkl
+  - bzip2=1.0.6=h14c3975_5
+  - ca-certificates=2019.1.23=0
+  - cairo=1.14.12=h8948797_3
+  - certifi=2019.3.9=py36_0
+  - cffi=1.12.3=py36h2e261b9_0
+  - cloudpickle=1.0.0=py_0
+  - cudatoolkit=9.0=h13b8566_0
+  - cycler=0.10.0=py36_0
+  - cytoolz=0.9.0.1=py36h14c3975_1
+  - dask-core=1.2.2=py_0
+  - dbus=1.13.6=h746ee38_0
+  - decorator=4.4.0=py36_1
+  - expat=2.2.6=he6710b0_0
+  - ffmpeg=4.0=hcdf2ecd_0
+  - fontconfig=2.13.0=h9420a91_0
+  - freeglut=3.0.0=hf484d3e_5
+  - freetype=2.9.1=h8a8886c_1
+  - glib=2.56.2=hd408876_0
+  - graphite2=1.3.13=h23475e2_0
+  - gst-plugins-base=1.14.0=hbbd80ab_1
+  - gstreamer=1.14.0=hb453b48_1
+  - harfbuzz=1.8.8=hffaf4a1_0
+  - hdf5=1.10.2=hba1933b_1
+  - icu=58.2=h9c2bf20_1
+  - imageio=2.5.0=py36_0
+  - intel-openmp=2019.3=199
+  - jasper=2.0.14=h07fcdf6_1
+  - jpeg=9b=h024ee3a_2
+  - kiwisolver=1.1.0=py36he6710b0_0
+  - libedit=3.1.20181209=hc058e9b_0
+  - libffi=3.2.1=hd88cf55_4
+  - libgcc-ng=8.2.0=hdf63c60_1
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libglu=9.0.0=hf484d3e_1
+  - libopencv=3.4.2=hb342d67_1
+  - libopus=1.3=h7b6447c_0
+  - libpng=1.6.37=hbc83047_0
+  - libstdcxx-ng=8.2.0=hdf63c60_1
+  - libtiff=4.0.10=h2733197_2
+  - libuuid=1.0.3=h1bed415_2
+  - libvpx=1.7.0=h439df22_0
+  - libxcb=1.13=h1bed415_1
+  - libxml2=2.9.9=he19cac6_0
+  - matplotlib=3.0.3=py36h5429711_0
+  - mkl=2019.3=199
+  - mkl_fft=1.0.12=py36ha843d7b_0
+  - mkl_random=1.0.2=py36hd81dba3_0
+  - ncurses=6.1=he6710b0_1
+  - networkx=2.3=py_0
+  - ninja=1.9.0=py36hfd86e86_0
+  - numpy=1.16.3=py36h7e9f1db_0
+  - numpy-base=1.16.3=py36hde5b4d6_0
+  - olefile=0.46=py36_0
+  - opencv=3.4.2=py36h6fd60c2_1
+  - openssl=1.1.1b=h7b6447c_1
+  - pcre=8.43=he6710b0_0
+  - pillow=6.0.0=py36h34e0f95_0
+  - pip=19.1.1=py36_0
+  - pixman=0.38.0=h7b6447c_0
+  - py-opencv=3.4.2=py36hb342d67_1
+  - pycparser=2.19=py36_0
+  - pyparsing=2.4.0=py_0
+  - pyqt=5.9.2=py36h05f1152_2
+  - python=3.6.8=h0371630_0
+  - python-dateutil=2.8.0=py36_0
+  - pytorch=1.1.0=py3.6_cuda9.0.176_cudnn7.5.1_0
+  - pytz=2019.1=py_0
+  - pywavelets=1.0.3=py36hdd07704_1
+  - qt=5.9.7=h5867ecd_1
+  - readline=7.0=h7b6447c_5
+  - scikit-image=0.15.0=py36he6710b0_0
+  - scipy=1.2.1=py36h7c811a0_0
+  - setuptools=41.0.1=py36_0
+  - sip=4.19.8=py36hf484d3e_0
+  - six=1.12.0=py36_0
+  - sqlite=3.28.0=h7b6447c_0
+  - tk=8.6.8=hbc83047_0
+  - toolz=0.9.0=py36_0
+  - torchvision=0.2.2=py_3
+  - tornado=6.0.2=py36h7b6447c_0
+  - wheel=0.33.4=py36_0
+  - xz=5.2.4=h14c3975_4
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.3.7=h0b5b093_0
diff --git a/input/.placeholder b/input/.placeholder
new file mode 100644
index 0000000..e69de29
diff --git a/monodepth_net.py b/monodepth_net.py
new file mode 100644
index 0000000..461db08
--- /dev/null
+++ b/monodepth_net.py
@@ -0,0 +1,186 @@
+"""MonoDepthNet: Network for monocular depth estimation trained by mixing several datasets.
+
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+class MonoDepthNet(nn.Module):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+        """
+        super().__init__()
+
+        resnet = models.resnet50(pretrained=False)
+
+        self.pretrained = nn.Module()
+        self.scratch = nn.Module()
+        self.pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
+                                               resnet.maxpool, resnet.layer1)
+
+        self.pretrained.layer2 = resnet.layer2
+        self.pretrained.layer3 = resnet.layer3
+        self.pretrained.layer4 = resnet.layer4
+
+        # adjust channel number of feature maps
+        self.scratch.layer1_rn = nn.Conv2d(256, features, kernel_size=3, stride=1, padding=1, bias=False)
+        self.scratch.layer2_rn = nn.Conv2d(512, features, kernel_size=3, stride=1, padding=1, bias=False)
+        self.scratch.layer3_rn = nn.Conv2d(1024, features, kernel_size=3, stride=1, padding=1, bias=False)
+        self.scratch.layer4_rn = nn.Conv2d(2048, features, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        # adaptive output module: 2 convolutions and upsampling
+        self.scratch.output_conv = nn.Sequential(nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+                                                 nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),
+                                                 Interpolate(scale_factor=2, mode='bilinear'))
+
+        # load model
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path)
+
+        self.load_state_dict(parameters)
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=False)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.resConfUnit = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit(xs[1])
+
+        output = self.resConfUnit(output)
+        output = nn.functional.interpolate(output, scale_factor=2,
+                                           mode='bilinear', align_corners=True)
+
+        return output
diff --git a/output/.placeholder b/output/.placeholder
new file mode 100644
index 0000000..e69de29
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..44e2527
--- /dev/null
+++ b/run.py
@@ -0,0 +1,71 @@
+"""Compute depth maps for images in the input folder.
+"""
+import os
+import glob
+import torch
+from monodepth_net import MonoDepthNet
+import utils
+
+
+def run(input_path, output_path, model_path):
+    """Run MonoDepthNN to compute depth maps.
+
+    Args:
+        input_path (str): path to input folder
+        output_path (str): path to output folder
+        model_path (str): path to saved model
+    """
+    print('initialize')
+
+    # select device
+    device = torch.device('cpu')
+    print('device: %s' % device)
+
+    # load network
+    model = MonoDepthNet(model_path)
+    model.to(device)
+    model.eval()
+
+    # get input
+    img_names = glob.glob(os.path.join(input_path, '*'))
+    num_images = len(img_names)
+
+    # create output folder
+    os.makedirs(output_path, exist_ok=True)
+
+    print("start processing")
+
+    for ind, img_name in enumerate(img_names):
+
+        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))
+
+        # input
+        img = utils.read_image(img_name)
+        img_input = utils.resize_image(img)
+        img_input = img_input.to(device)
+
+        # compute
+        with torch.no_grad():
+            out = model.forward(img_input)
+
+        depth = utils.resize_depth(out, img.shape[1], img.shape[0])
+
+        # output
+        filename = os.path.join(output_path, os.path.splitext(os.path.basename(img_name))[0])
+        utils.write_depth(filename, depth)
+
+    print("finished")
+
+
+if __name__ == '__main__':
+    # set paths
+    INPUT_PATH = "input"
+    OUTPUT_PATH = "output"
+    MODEL_PATH = "model.pt"
+
+    # set torch options
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+
+    # compute depth maps
+    run(INPUT_PATH, OUTPUT_PATH, MODEL_PATH)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..c077108
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,181 @@
+"""Utils for monoDepth.
+"""
+import sys
+import re
+import numpy as np
+import skimage
+import cv2
+import torch
+
+
+def read_pfm(path):
+    """Read pfm file.
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, 'rb') as file:
+
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().rstrip()
+        if header.decode("ascii") == 'PF':
+            color = True
+        elif header.decode("ascii") == 'Pf':
+            color = False
+        else:
+            raise Exception('Not a PFM file: ' + path)
+
+        dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
+        if dim_match:
+            width, height = list(map(int, dim_match.groups()))
+        else:
+            raise Exception('Malformed PFM header.')
+
+        scale = float(file.readline().decode("ascii").rstrip())
+        if scale < 0:
+            # little-endian
+            endian = '<'
+            scale = -scale
+        else:
+            # big-endian
+            endian = '>'
+
+        data = np.fromfile(file, endian + 'f')
+        shape = (height, width, 3) if color else (height, width)
+
+        data = np.reshape(data, shape)
+        data = np.flipud(data)
+
+        return data, scale
+
+
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+
+    Args:
+        path (str): path to file
+        image (array): data
+        scale (int, optional): Scale. Defaults to 1.
+    """
+
+    with open(path, 'wb') as file:
+        color = None
+
+        if image.dtype.name != 'float32':
+            raise Exception('Image dtype must be float32.')
+
+        image = np.flipud(image)
+
+        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+            color = True
+        elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
+            color = False
+        else:
+            raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
+
+        file.write(('PF\n' if color else 'Pf\n').encode())
+        file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
+
+        endian = image.dtype.byteorder
+
+        if endian == '<' or endian == '=' and sys.byteorder == 'little':
+            scale = -scale
+
+        file.write('%f\n'.encode() % scale)
+
+        image.tofile(file)
+
+
+def read_image(path):
+    """Read image and output RGB image (0-1).
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        array: RGB image (0-1)
+    """
+    img = skimage.io.imread(path)
+
+    if img.ndim == 2:
+        img = skimage.color.gray2rgb(img)
+
+    img = np.float32(img) / 255.0
+
+    return img
+
+
+def resize_image(img):
+    """Resize image and make it fit for network.
+
+    Args:
+        img (array): image
+
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig = img.shape[0]
+    width_orig = img.shape[1]
+
+    if width_orig > height_orig:
+        scale = width_orig / 384
+    else:
+        scale = height_orig / 384
+
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+    img_resized = (torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float())
+    img_resized = img_resized.unsqueeze(0)
+
+    return img_resized
+
+
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+
+    Returns:
+        array: processed depth
+    """
+    depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
+
+    depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC)
+
+    return depth_resized
+
+
+def write_depth(path, depth):
+    """Write depth map to pfm and png file.
+
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+    """
+    write_pfm(path + '.pfm', depth.astype(np.float32))
+
+    depth_min = depth.min()
+    depth_max = depth.max()
+
+    if depth_max - depth_min > np.finfo('float').eps:
+        out = 255 * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        out = np.zeros(depth.shape)
+
+    skimage.io.imsave(path + '.png', out.astype('uint8'))
+
+    return