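"""End-to-end Sea-thru on a monocular image: estimate a depth map with a
pretrained monodepth2 model, then run the Sea-thru pipeline to restore the
underwater image's colors."""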
from __future__ import absolute_import, division, print_function
import os
import argparse

import numpy as np
import PIL.Image as pil
import rawpy  # decodes RAW files when --raw is passed
import torch
from torchvision import transforms
# used in the final denoising step below
from skimage.restoration import denoise_tv_chambolle, estimate_sigma

import deps.monodepth2.networks as networks
from deps.monodepth2.utils import download_model_if_doesnt_exist
from seathru import *


def run(args):
    """Predict a depth map for a single image and restore it with the Sea-thru pipeline."""
    assert args.model_name is not None, \
        "You must specify the --model-name parameter; see README.md for an example"
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    download_model_if_doesnt_exist(args.model_name)
    model_path = os.path.join("models", args.model_name)
    print("-> Loading model from ", model_path)
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")
    # LOADING PRETRAINED MODEL
    print("   Loading pretrained encoder")
    encoder = networks.ResnetEncoder(18, False)
    loaded_dict_enc = torch.load(encoder_path, map_location=device)

    # Extract the height and width of image that this model was trained with
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']
    filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()}
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    print("   Loading pretrained decoder")
    depth_decoder = networks.DepthDecoder(
        num_ch_enc=encoder.num_ch_enc, scales=range(4))
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    depth_decoder.load_state_dict(loaded_dict)
    depth_decoder.to(device)
    depth_decoder.eval()
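
    # Load the input and cap it to --size pixels on the longest side; depth is
    # predicted at the network's training resolution and resized back, while
    # Sea-thru itself runs on the capped image.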
    if args.raw:
        img = pil.fromarray(rawpy.imread(args.image).postprocess())
    else:
        img = pil.open(args.image).convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
    img.thumbnail((args.size, args.size), pil.LANCZOS)
    original_width, original_height = img.size
    input_image = img.resize((feed_width, feed_height), pil.LANCZOS)
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    print('Preprocessed image', flush=True)
    # PREDICTION
    input_image = input_image.to(device)
    # Inference only, so skip building the autograd graph
    with torch.no_grad():
        features = encoder(input_image)
        outputs = depth_decoder(features)
    disp = outputs[("disp", 0)]
    disp_resized = torch.nn.functional.interpolate(
        disp, (original_height, original_width), mode="bilinear", align_corners=False)

    # Normalize the predicted disparity map to [0, 1]
    disp_resized_np = disp_resized.squeeze().cpu().detach().numpy()
    mapped_im_depths = ((disp_resized_np - np.min(disp_resized_np)) / (
            np.max(disp_resized_np) - np.min(disp_resized_np))).astype(np.float32)
    print("Processed image", flush=True)
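
    # preprocess_monodepth_depth_map (from seathru) rescales the normalized
    # disparity into the depth range Sea-thru expects; --monodepth-add-depth and
    # --monodepth-multiply-depth set the offset and scale (the exact mapping is
    # defined in seathru.py).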
    depths = preprocess_monodepth_depth_map(mapped_im_depths, args.monodepth_add_depth,
                                            args.monodepth_multiply_depth)
    recovered = run_pipeline(np.array(img) / 255.0, depths, args)

    # Light denoising of the restored image; older scikit-image releases spell
    # channel_axis=-1 as multichannel=True
    sigma_est = estimate_sigma(recovered, channel_axis=-1, average_sigmas=True) / 10.0
    recovered = denoise_tv_chambolle(recovered, weight=sigma_est, channel_axis=-1)

    im = pil.fromarray((np.round(recovered * 255.0)).astype(np.uint8))
    im.save(args.output, format='png')
    print('Done.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', required=True, help='Input image')
    parser.add_argument('--output', default='output.png', help='Output filename')
    parser.add_argument('--f', type=float, default=2.0, help='f value (controls brightness)')
    parser.add_argument('--l', type=float, default=0.5, help='l value (controls balance of attenuation constants)')
    parser.add_argument('--p', type=float, default=0.01, help='p value (controls locality of illuminant map)')
    parser.add_argument('--min-depth', type=float, default=0.0,
                        help='Minimum depth value to use in estimations (range 0-1)')
    parser.add_argument('--max-depth', type=float, default=1.0,
                        help='Replacement depth percentile value for invalid depths (range 0-1)')
    parser.add_argument('--spread-data-fraction', type=float, default=0.05,
                        help='Require data to be this fraction of the depth range away from each other in attenuation estimations')
    parser.add_argument('--size', type=int, default=320, help='Maximum side length the input is downscaled to')
    parser.add_argument('--monodepth-add-depth', type=float, default=2.0, help='Additive value for monodepth map')
    parser.add_argument('--monodepth-multiply-depth', type=float, default=10.0,
                        help='Multiplicative value for monodepth map')
    parser.add_argument('--model-name', type=str, default="mono_1024x320",
                        help='monodepth model name')
    parser.add_argument('--no-cuda', action='store_true', help='Run on the CPU even if CUDA is available')
    parser.add_argument('--output-graphs', action='store_true', help='Output graphs')
    parser.add_argument('--raw', action='store_true', help='Treat the input as a RAW file (decoded with rawpy)')
    args = parser.parse_args()
    run(args)
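
# Example invocation (hypothetical filenames; the monodepth model is downloaded
# on first use):
#   python seathru-mono-e2e.py --image examples/reef.jpg --output restored.png
#   python seathru-mono-e2e.py --image examples/reef.ARW --raw --size 640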