-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
262 lines (229 loc) · 9.71 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for RegNeRF."""
import functools
from os import path
import time
from absl import app
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
from internal import configs, datasets, math, models, utils, vis # pylint: disable=g-multiple-import
import jax
from jax import random
import numpy as np
from skimage.metrics import structural_similarity
import tensorflow as tf
######################################################
## ----- This file is modified significantly ------ ##
######################################################
## ---- import necessary pkgs ---- ##
from lpips import LPIPS
import torch
# LPIPS metric model with a VGG backbone. It is built once, at import time,
# and moved to the CUDA device via `.cuda()`, so importing this module
# requires a working CUDA setup.
lpips_vgg = LPIPS(net="vgg").cuda()
## --------------------------------- ##
CENSUS_EPSILON = 1 / 256  # Guard against ground-truth quantization.
# Flag setup must happen at import time, before absl parses argv in app.run:
# presumably registers the flags read by `configs.load_config` (confirm in
# internal/configs.py), then lets JAX pick up its own absl flags.
configs.define_common_flags()
jax.config.parse_flags_with_absl()
def _ssim_fn(x, y):
  """Returns SSIM between two images whose last axis holds the channels.

  `data_range=1.0` means pixel values are expected to lie in [0, 1].
  """
  ## ---- fix the SSIM issue default in regnerf's code ---- ##
  # NOTE(review): scikit-image >= 0.19 deprecates `multichannel` in favour of
  # `channel_axis=-1`; kept as-is for compatibility with older versions.
  return structural_similarity(x, y, multichannel=True, data_range=1.0)


def _lpips_fn(x, y):
  """Returns the LPIPS (VGG) distance between two HxWxC images as a float.

  Inputs are assumed to be in [0, 1] (the same range SSIM is computed with).
  """
  def to_tensor(img):
    # HWC array -> 1xCxHxW float32 CUDA tensor (LPIPS takes NCHW batches).
    return torch.from_numpy(np.array(img)).float().cuda().permute(
        2, 0, 1).unsqueeze(0)

  # LPIPS expects inputs in [-1, 1]; `normalize=True` rescales from [0, 1].
  # no_grad: the score is only read, so no autograd graph is needed.
  with torch.no_grad():
    score = lpips_vgg(to_tensor(x), to_tensor(y), normalize=True)
  return score.item()


def _restore_state(config, state):
  """Restores the train state from `config.checkpoint_dir`.

  A plain restore into `state` is tried first. If that raises, the checkpoint
  is treated as a pre-trained model whose `MLP_0` layer numbering differs:
  `Dense_9`/`Dense_17` are dropped, `Dense_18..20` are copied into the
  `Dense_9..11` slots, the `optimizerd` entry is removed, and the result is
  loaded into `state`.

  Args:
    config: Loaded config; only `checkpoint_dir` is read.
    state: TrainState used as the restore target.

  Returns:
    The restored TrainState.
  """
  try:
    return checkpoints.restore_checkpoint(config.checkpoint_dir, state)
  except Exception:  # pylint: disable=broad-except
    # Fix for loading pre-trained models.
    print('Using pre-trained model.')
  state_dict = checkpoints.restore_checkpoint(config.checkpoint_dir, None)
  mlp_params = state_dict['optimizer']['target']['params']['MLP_0']
  for i in [9, 17]:
    del mlp_params[f'Dense_{i}']
  mlp_params['Dense_9'] = mlp_params['Dense_18']
  mlp_params['Dense_10'] = mlp_params['Dense_19']
  mlp_params['Dense_11'] = mlp_params['Dense_20']
  # NOTE(review): presumably an optimizer entry this TrainState has no field
  # for; confirm against the pre-trained checkpoint layout.
  del state_dict['optimizerd']
  return flax.serialization.from_state_dict(state, state_dict)


def _build_render_eval_pfn(model, config, step):
  """Builds the pmapped evaluation render function.

  Args:
    model: Model returned by `models.construct_mipnerf`.
    config: Loaded config (reads `freq_reg*`, `max_vis_freq_ratio`,
      `resample_padding_final`).
    step: Checkpoint step, used to build the frequency-regularization masks.

  Returns:
    `jax.pmap`-ed `(variables, _, rays)` function; only `rays` is mapped, and
    the first output of `model.apply` is all-gathered across devices.
  """
  apply_kwargs = {}
  if config.freq_reg:
    # NOTE(review): 99 and 27 are presumably the encoded position and
    # view-direction feature sizes; confirm against internal/math.py.
    apply_kwargs['freq_reg_mask'] = (
        math.get_freq_reg_mask(99, step, config.freq_reg_end,
                               config.max_vis_freq_ratio),
        math.get_freq_reg_mask(27, step, config.freq_reg_end,
                               config.max_vis_freq_ratio))

  def render_eval_fn(variables, _, rays):
    return jax.lax.all_gather(
        model.apply(
            variables,
            # Rendering is forced to be deterministic even if training was
            # randomized, as this eliminates 'speckle' artifacts.
            None,  # Deterministic.
            rays,
            resample_padding=config.resample_padding_final,
            compute_extras=True,
            **apply_kwargs)[0],
        axis_name='batch')

  # pmap over only the data input.
  return jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,
      axis_name='batch',
  )


def _compute_image_metrics(rendering, batch, config, census_fn):
  """Computes the per-image metrics for one rendered test view.

  Args:
    rendering: Rendered outputs; reads 'rgb' and, depending on config,
      'distance_mean' and 'normals'.
    batch: Ground truth; reads 'rgb' and, depending on config, 'disps',
      'normals' and (DTU only) 'mask'.
    config: Loaded config.
    census_fn: Census-error function taking (prediction, target).

  Returns:
    Dict mapping metric name to a Python float, in a fixed key order.
  """
  pred_rgb, gt_rgb = rendering['rgb'], batch['rgb']
  metric = {}
  metric['psnr'] = float(math.mse_to_psnr(((pred_rgb - gt_rgb)**2).mean()))
  metric['ssim'] = float(_ssim_fn(pred_rgb, gt_rgb))
  metric['lpips'] = float(_lpips_fn(pred_rgb, gt_rgb))
  metric['avg_err'] = float(
      math.compute_avg_error(
          psnr=metric['psnr'],
          ssim=metric['ssim'],
          lpips=metric['lpips'],
      ))
  metric['census_err'] = float(census_fn(pred_rgb, gt_rgb))
  if config.compute_disp_metrics:
    disp = 1 / (1 + rendering['distance_mean'])
    metric['disp_mse'] = float(((disp - batch['disps'])**2).mean())
  if config.compute_normal_metrics:
    # Mean angle (radians) between GT and predicted normals; the clip keeps
    # arccos away from its non-differentiable/NaN endpoints.
    one_eps = 1 - np.finfo(np.float32).eps
    metric['normal_mae'] = float(
        np.arccos(
            np.clip(
                np.sum(batch['normals'] * rendering['normals'], axis=-1),
                -one_eps, one_eps)).mean())
  if config.dataset_loader == 'dtu':
    mask = batch['mask']
    mask_bin = (mask == 1.)
    # Pixels outside the mask are filled with 1.0 (white for [0, 1] images).
    gt_fg = gt_rgb * mask + (1 - mask)
    pred_fg = pred_rgb * mask + (1 - mask)
    metric['psnr_masked'] = float(
        math.mse_to_psnr(((gt_rgb - pred_rgb)[mask_bin]**2).mean()))
    metric['ssim_masked'] = float(_ssim_fn(pred_fg, gt_fg))
    metric['lpips_masked'] = float(_lpips_fn(pred_fg, gt_fg))
    metric['avg_err_masked'] = float(
        math.compute_avg_error(
            psnr=metric['psnr_masked'],
            ssim=metric['ssim_masked'],
            lpips=metric['lpips_masked'],
        ))
  return metric


def main(unused_argv):
  """Evaluates the checkpoint in `config.checkpoint_dir` on the test split.

  Every test image (or camera-path pose when `config.render_path` is set) is
  rendered. On host 0, per-image metrics are computed (unless rendering a
  path), printed, optionally written to `out_dir`, and averaged into
  TensorBoard scalars (synced to W&B when enabled).

  Args:
    unused_argv: Leftover positional command-line arguments; ignored.
  """
  # Hide accelerators from TensorFlow so it does not claim GPU/TPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  tf.config.experimental.set_visible_devices([], 'TPU')

  config = configs.load_config(save_config=False)

  if config.use_wandb:
    import wandb  # pylint: disable=g-import-not-at-top
    wandb.init(
        project=config.project, entity=config.entity, sync_tensorboard=True)
    wandb.run.name = config.expname
    wandb.run.save()
    wandb.config.update(config)

  dataset = datasets.load_dataset('test', config.data_dir, config)
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek()['rays'], config)
  # The optimizer only supplies the TrainState structure that checkpoints are
  # restored into; no updates are applied in this script.
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  census_fn = jax.jit(
      functools.partial(math.compute_census_err, epsilon=CENSUS_EPSILON))

  last_step = 0
  out_dir = path.join(config.checkpoint_dir,
                      'path_renders' if config.render_path else 'test_preds')
  path_fn = lambda x: path.join(out_dir, x)
  summary_writer = tensorboard.SummaryWriter(
      path.join(config.checkpoint_dir, 'eval'))

  state = _restore_state(config, state)
  step = int(state.optimizer.state.step)
  render_eval_pfn = _build_render_eval_pfn(model, config, step)

  if step <= last_step:
    print(f'Checkpoint step {step} <= last step {last_step}, exit.')
    return
  print(f'Evaluating checkpoint at step {step}.')
  if config.eval_save_output and (not utils.isdir(out_dir)):
    utils.makedirs(out_dir)

  # Full renderings are only kept in memory when they will be logged as
  # images; in that case every test view is logged.
  log_images = config.use_wandb and config.log_img_to_wandb
  metrics = []
  showcases = []
  for idx in range(dataset.size):
    print(f'Evaluating image {idx+1}/{dataset.size}')
    eval_start_time = time.time()
    batch = next(dataset)
    rendering = models.render_image(
        functools.partial(render_eval_pfn, state.optimizer.target),
        batch['rays'],
        None,
        config)
    print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

    if jax.host_id() != 0:  # Only record via host 0.
      continue

    if log_images:
      showcases.append((idx, rendering, batch))

    if not config.render_path:
      metric = _compute_image_metrics(rendering, batch, config, census_fn)
      for m, v in metric.items():
        print(f'{m:10s} = {v:.4f}')
      metrics.append(metric)

    if config.eval_save_output and (config.eval_render_interval > 0):
      if (idx % config.eval_render_interval) == 0:
        utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx:03d}.png'))
        # Map normals from [-1, 1] to [0, 1] for 8-bit saving.
        utils.save_img_u8(rendering['normals'] / 2. + 0.5,
                          path_fn(f'normals_{idx:03d}.png'))
        utils.save_img_f32(rendering['distance_mean'],
                           path_fn(f'distance_mean_{idx:03d}.tiff'))
        utils.save_img_f32(rendering['distance_median'],
                           path_fn(f'distance_median_{idx:03d}.tiff'))
        utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))

  # `metrics` is empty when rendering a camera path (no ground truth) or on
  # an empty dataset; indexing metrics[0] unguarded would raise IndexError.
  metric_names = list(metrics[0].keys()) if metrics else []

  if jax.host_id() == 0:
    for name in metric_names:
      summary_writer.scalar(name, np.mean([m[name] for m in metrics]), step)
    if log_images:
      for i, r, b in showcases:
        for k, v in vis.visualize_suite(r, b['rays'], config).items():
          summary_writer.image(f'pred_{k}_{i}', v, step)
        if not config.render_path:
          summary_writer.image(f'target_{i}', b['rgb'], step)

  if (config.eval_save_output and (not config.render_path) and
      (jax.host_id() == 0)):
    print('#####################')
    for name in metric_names:
      with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
        f.write(' '.join([str(m[name]) for m in metrics]))
      print(f'{name}:', np.mean([m[name] for m in metrics]))
  print('evaluated exp:', config.expname)
  if config.use_wandb:
    wandb.finish()
if __name__ == '__main__':
  # absl parses the command-line flags, then invokes main(argv).
  app.run(main)