# -*- coding: utf-8 -*-
# @Time    : 2018/10/21 20:57
# @Author  : Wang Xin
# @Email   : [email protected]
import argparse
import glob
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Colormap used to render depth maps as RGB images.
cmap = plt.cm.jet


def parse_command():
    parser = argparse.ArgumentParser(description='FCRN')
    parser.add_argument('--decoder', default='upproj', type=str)
    parser.add_argument('--resume', default=None, type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        help='mini-batch size (default: 16)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run (default: 200)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate (default: 0.01)')
    parser.add_argument('--lr_patience', default=2, type=int,
                        help='patience of the LR scheduler; see the documentation of ReduceLROnPlateau')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum (default: 0.9)')
    parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 0.0005)')
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--dataset', type=str, default='nyu')
    parser.add_argument('--manual_seed', default=1, type=int, help='manually set random seed')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--dataset-dir', required=True, type=str, help='path to dataset')
    parser.add_argument('--upper-limit', default=1.0e6, type=float, help='upper limit of depth')
    parser.add_argument('--validate-only', action='store_true', help='perform validation only')
    parser.add_argument('--loss-func', default='l1', type=str, help='loss function',
                        choices=['l1', 'berhu'])
    parser.add_argument('--average-lr', action='store_true',
                        help='use a unified learning rate for all layers')
    args = parser.parse_args()
    return args
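
# A minimal invocation sketch (the entry-point script name is an assumption,
# not taken from this file):
#   python main.py --dataset-dir /path/to/nyu -b 16 --epochs 200 --loss-func berhu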


def get_output_directory(args):
    """Resume into the directory of the given checkpoint, or create a new run_<id> folder."""
    if args.resume:
        return os.path.dirname(args.resume)
    else:
        save_dir_root = os.path.dirname(os.path.abspath(__file__))
        save_dir_root = os.path.join(save_dir_root, 'result', args.decoder)
        runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')),
                      key=lambda x: int(x.split('_')[-1]))
        run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
        save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
        return save_dir
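
# For example, with --decoder upproj and no --resume, successive runs are written to
# result/upproj/run_0, result/upproj/run_1, ... relative to this file's directory.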


# Save a checkpoint for the current epoch; if it is the best so far, also copy it
# to model_best.pth.tar.
def save_checkpoint(state, is_best, epoch, output_directory):
    checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, checkpoint_filename)
    if is_best:
        best_filename = os.path.join(output_directory, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_filename, best_filename)
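
# Usage sketch (the state-dict keys below follow a common PyTorch convention and are
# an assumption here; save_checkpoint itself accepts any picklable state):
#   state = {'epoch': epoch,
#            'model': model.state_dict(),
#            'optimizer': optimizer.state_dict()}
#   save_checkpoint(state, is_best, epoch, output_directory)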


def colored_depthmap(depth, d_min=None, d_max=None):
    """Map a 2-D depth array to an RGB image in [0, 255] using the jet colormap."""
    if d_min is None:
        d_min = np.min(depth)
    if d_max is None:
        d_max = np.max(depth)
    # Guard against division by zero when the depth map is constant.
    depth_relative = (depth - d_min) / max(d_max - d_min, np.finfo(float).eps)
    return 255 * cmap(depth_relative)[:, :, :3]  # H, W, C
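
# Example: colorize a synthetic depth map (shape and value range are arbitrary):
#   depth = np.random.uniform(0.5, 10.0, size=(228, 304))
#   color = colored_depthmap(depth)  # float array of shape (228, 304, 3)
#   save_image(color, 'depth_color.png')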


def merge_into_row(input, depth_target, depth_pred, args):
    """Stack the RGB input, ground-truth depth and predicted depth side by side."""
    rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
    depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
    depth_pred_cpu = np.squeeze(depth_pred.detach().cpu().numpy())
    # Zero out target pixels beyond the depth upper limit and clip the prediction
    # to the same range, so both maps share a comparable color scale.
    upper_limit_mask_cpu = depth_target_cpu > args.upper_limit
    depth_target_cpu[upper_limit_mask_cpu] = 0.0
    depth_pred_cpu[upper_limit_mask_cpu] = 0.0
    depth_pred_cpu = np.clip(depth_pred_cpu, 0, args.upper_limit)
    d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
    d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
    depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
    depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
    img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
    return img_merge


def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
    """Stack the RGB input, input depth, ground-truth depth and prediction side by side."""
    rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
    depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
    depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
    depth_pred_cpu = np.squeeze(depth_pred.detach().cpu().numpy())
    # Share one color scale across all three depth maps.
    d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
    d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
    depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
    depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
    depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
    img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
    return img_merge


def add_row(img_merge, row):
    return np.vstack([img_merge, row])


def save_image(img_merge, filename):
    img_merge = Image.fromarray(img_merge.astype('uint8'))
    img_merge.save(filename)
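

# Minimal smoke test for the visualization helpers. The tensors are random and the
# 1 x C x H x W shapes follow the convention assumed by the functions above; the
# output filename is illustrative.
if __name__ == '__main__':
    rgb = torch.rand(1, 3, 228, 304)            # fake RGB input in [0, 1]
    target = torch.rand(1, 1, 228, 304) * 10.0  # fake ground-truth depth
    pred = torch.rand(1, 1, 228, 304) * 10.0    # fake predicted depth
    row = merge_into_row_with_gt(rgb, target, target, pred)
    save_image(row, 'demo_row.png')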