diff --git a/renderer.py b/renderer.py
index 96c7d68..05d1ed3 100644
--- a/renderer.py
+++ b/renderer.py
@@ -78,7 +78,7 @@ def forward(ctx, x):
     @staticmethod
     def backward(ctx, g):
         x = ctx.saved_tensors[0]
-        return g * torch.exp(x.clamp(-3, 1))
+        return g * torch.exp(x.clamp(-1, 1))
 
 trunc_exp = _trunc_exp.apply
 
diff --git a/splatter.py b/splatter.py
index 6d76c33..4fe9060 100644
--- a/splatter.py
+++ b/splatter.py
@@ -57,10 +57,10 @@ def to_scale_matrix(self):
 
     def _tocpp(self):
         _cobj = gaussian.Gaussian3ds()
-        _cobj.pos = self.pos
-        _cobj.rgb = self.rgb
-        _cobj.opa = self.opa
-        _cobj.cov = self.cov
+        _cobj.pos = self.pos.clone()
+        _cobj.rgb = self.rgb.clone()
+        _cobj.opa = self.opa.clone()
+        _cobj.cov = self.cov.clone()
         return _cobj
 
     def to(self, *args, **kwargs):
@@ -93,12 +93,18 @@ def filte(self, mask):
             cov=self.cov[mask],
         )
 
-    def get_gaussian_3d_conv(self):
+    def get_gaussian_3d_cov(self, scale_activation="abs"):
         R = q2r(self.quat)
-        _scale = self.scale.abs()+1e-4
+        if scale_activation == "abs":
+            _scale = self.scale.abs()+1e-4
+        elif scale_activation == "exp":
+            _scale = trunc_exp(self.scale)
+        else:
+            print("Unsupported scale activation")
+            exit()
+
+        # _scale = trunc_exp(self.scale)
         # _scale = torch.clamp(_scale, min=1e-4, max=0.1)
         S = torch.diag_embed(_scale)
-        # S = torch.diag_embed(trunc_exp(self.scale))
         RS = torch.bmm(R, S)
         RSSR = torch.bmm(RS, RS.permute(0,2,1))
         return RSSR
@@ -109,15 +115,22 @@ def get_gaussian_3d_conv(self):
     def reset_opa(self):
         torch.nn.init.uniform_(self.opa, a=inverse_sigmoid(0.1), b=inverse_sigmoid(0.11))
 
-    def adaptive_control(self, grad, taus, delete_thresh):
+    def adaptive_control(self, grad, taus, delete_thresh, scale_activation="abs"):
         # grad: B x 3
         # densification
         # 1. delete gaussians with small opacities
         assert self.init_values # only for the initial gaussians
-        print(inverse_sigmoid(0.1))
+        print(inverse_sigmoid(0.01))
         print(self.opa.min())
         print(self.opa.max())
-        _mask = (self.opa > inverse_sigmoid(0.1)) & (self.scale.norm(dim=-1) < delete_thresh)
+        if scale_activation == "abs":
+            _mask = (self.opa > inverse_sigmoid(0.01)) & (self.scale.norm(dim=-1) < delete_thresh)
+        elif scale_activation == "exp":
+            _mask = (self.opa > inverse_sigmoid(0.01)) & (self.scale.exp().norm(dim=-1) < delete_thresh)
+        else:
+            print("Wrong activation")
+            exit()
+
         self.pos = nn.parameter.Parameter(self.pos.detach()[_mask])
         self.rgb = nn.parameter.Parameter(self.rgb.detach()[_mask])
         self.opa = nn.parameter.Parameter(self.opa.detach()[_mask])
@@ -127,13 +140,13 @@ def adaptive_control(self, grad, taus, delete_thresh):
         print("DELETE: {} Gaussians".format((~_mask).sum()))
         # 2. clone or split
         densify_mask = grad.abs().max(-1)[0] > 0.0002
-        cat_pos = [self.pos.detach()]
-        cat_rgb = [self.rgb.detach()]
-        cat_opa = [self.opa.detach()]
-        cat_quat = [self.quat.detach()]
-        cat_scale = [self.scale.detach()]
+        cat_pos = [self.pos.clone().detach()]
+        cat_rgb = [self.rgb.clone().detach()]
+        cat_opa = [self.opa.clone().detach()]
+        cat_quat = [self.quat.clone().detach()]
+        cat_scale = [self.scale.clone().detach()]
         if densify_mask.any():
-            scale_norm = self.scale.norm(dim=-1)
+            scale_norm = self.scale.norm(dim=-1) if scale_activation == "abs" else self.scale.exp().norm(dim=-1)
             split_mask = scale_norm > taus
             clone_mask = scale_norm <= taus
             split_mask = split_mask & densify_mask
@@ -141,7 +154,7 @@ def adaptive_control(self, grad, taus, delete_thresh):
             if clone_mask.any():
                 cloned_pos = self.pos[clone_mask].clone().detach()
                 # cloned_pos -= grad[clone_mask] * 0.01
-                cloned_pos -= grad[clone_mask] * 10
+                cloned_pos -= grad[clone_mask] * 0.01
                 cloned_rgb = self.rgb[clone_mask].clone().detach()
                 cloned_opa = self.opa[clone_mask].clone().detach()
                 cloned_quat = self.quat[clone_mask].clone().detach()
@@ -154,24 +167,33 @@ def adaptive_control(self, grad, taus, delete_thresh):
                 cat_scale.append(cloned_scale)
 
             if split_mask.any():
-                _scale = self.scale.detach()
-                _scale[split_mask] /= 1.6
-
-                cat_scale[0][split_mask] /= 1.6
-                self.scale = nn.parameter.Parameter(_scale)
+                _scale = self.scale.clone().detach()
+                if scale_activation == "abs":
+                    _scale[split_mask] /= 1.6
+                elif scale_activation == "exp":
+                    _scale[split_mask] -= math.log(1.6)
+                else:
+                    print("Wrong activation")
+                    exit()
+
+                cat_scale[0] = _scale
+                # cat_scale[0][split_mask] /= 1.6
+                # self.scale = nn.parameter.Parameter(_scale)
                 # sampling two positions
-                this_cov = self.get_gaussian_3d_conv()[split_mask]
+                this_cov = self.get_gaussian_3d_cov(scale_activation=scale_activation)[split_mask]
                 p1, p2 = sample_two_point(self.pos[split_mask], this_cov)
                 # split_pos = self.pos[split_mask].clone().detach()
                 # split_pos -= grad[split_mask] * 0.01
-                cat_pos[0][split_mask] = p1.detach()
+                origin_pos = cat_pos[0]
+                origin_pos[split_mask] = p1.detach()
+                cat_pos[0] = origin_pos
                 split_pos = p2.detach()
                 split_rgb = self.rgb[split_mask].clone().detach()
                 split_opa = self.opa[split_mask].clone().detach()
                 split_quat = self.quat[split_mask].clone().detach()
-                split_scale = self.scale[split_mask].clone().detach()
+                split_scale = _scale[split_mask].clone()
                 print("SPLIT : {} Gaussians".format(split_pos.shape[0]))
                 cat_pos.append(split_pos)
                 cat_rgb.append(split_rgb)
@@ -184,7 +206,7 @@ def adaptive_control(self, grad, taus, delete_thresh):
         self.quat = nn.parameter.Parameter(torch.cat(cat_quat))
         self.scale = nn.parameter.Parameter(torch.cat(cat_scale))
 
-    def project(self, rot, tran, near, jacobian_calc):
+    def project(self, rot, tran, near, jacobian_calc, scale_activation="abs"):
         with Timer(" w2c"):
             pos_cam_space = world_to_camera(self.pos, rot, tran)
 
@@ -198,9 +220,12 @@ def project(self, rot, tran, near, jacobian_calc):
         else:
             jacobian = jacobian_torch(pos_cam_space)
 
-        with Timer(" cov"):
-            gaussian_3d_cov = self.get_gaussian_3d_conv()
-            JW = torch.einsum("bij,bjk->bik", jacobian, rot.unsqueeze(dim=0))
+        with Timer(" cov1"):
+            gaussian_3d_cov = self.get_gaussian_3d_cov(scale_activation=scale_activation)
+            # JW = torch.einsum("bij,bjk->bik", jacobian, rot.unsqueeze(dim=0))
+        with Timer(" cov2"):
+            JW = torch.matmul(jacobian, rot.unsqueeze(dim=0))
+        with Timer(" cov3"):
             gaussian_2d_cov = torch.bmm(torch.bmm(JW, gaussian_3d_cov), JW.permute(0,2,1))[:, :2, :2]
with Timer(" last"): @@ -229,8 +254,8 @@ def crop(self, image): # output: height x width x 3 top = int(self.padded_height - self.height)//2 left = int(self.padded_width - self.width)//2 - # return image[top:top+int(self.height), left:left+int(self.width), :] - return image[top:top+int(self.height)-1, left:left+int(self.width)-1, :] + #return image[top:top+int(self.height), left:left+int(self.width), :] + return image[top:top+int(self.height)-1, left:left+int(self.width), :] def create_tiles(self): self.tiles_left = torch.linspace(-self.padded_width/2, self.padded_width/2, self.n_tile_x + 1, device=self.device)[:-1] @@ -263,7 +288,6 @@ def create_tiles(self): def __len__(self): return self.tiles_top.shape[0] - class Splatter(nn.Module): def __init__(self, colmap_path, @@ -280,6 +304,8 @@ def __init__(self, tile_culling_dist_thresh=0.5, tile_culling_prob_thresh=0.1, debug=1, + scale_activation="abs", + cudaculling=0, ): super().__init__() self.device = torch.device("cuda") @@ -292,6 +318,8 @@ def __init__(self, self.tile_culling_dist_thresh = tile_culling_dist_thresh self.tile_culling_prob_thresh = tile_culling_prob_thresh self.debug = debug + self.scale_activation = scale_activation + self.cudaculling = cudaculling assert jacobian_calc == "cuda" or jacobian_calc == "torch" self.points3d = read_points3d_binary(os.path.join(colmap_path, "points3D.bin")) @@ -313,13 +341,26 @@ def __init__(self, rgb = torch.stack(_rgbs).to(torch.float32).to(self.device) # B x 3 if self.use_sh_coeff: rgb = initialize_sh(rgb) + + _pos=torch.stack(_points).to(torch.float32).to(self.device) + mean_min_three_dis = [] + for i_pos in tqdm(range(_pos.shape[0])): + _r = (_pos[i_pos:i_pos+1] - _pos).norm(dim=-1).sort(dim=-1)[0][1:4].mean().item() + mean_min_three_dis.append(_r) + mean_min_three_dis = torch.Tensor(mean_min_three_dis).to(torch.float32) * scale_init_value + + if scale_activation == "exp": + mean_min_three_dis = mean_min_three_dis.log() + self.gaussian_3ds = Gaussian3ds( - pos=torch.stack(_points).to(torch.float32).to(self.device), # B x 3 + pos=_pos.to(self.device), # B x 3 rgb = rgb, # B x 3 or 27 opa = torch.ones(len(_points)).to(torch.float32).to(self.device)*inverse_sigmoid(opa_init_value), # B quat = torch.Tensor([1, 0, 0, 0]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4 # quat = torch.Tensor([1, 1, 2, 1]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4 - scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*scale_init_value, + #scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*math.log(scale_init_value), + #scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*scale_init_value, + scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*mean_min_three_dis.unsqueeze(dim=1).to(self.device), init_values=True, ) self.current_camera = None @@ -372,65 +413,61 @@ def set_camera(self, idx): # print("current_camera info") # print(self.current_camera) - def project_and_culling(self, cudaculling=False): + def project_and_culling(self): # project 3D to 2D # print(f"number of gaussians {len(self.gaussian_3ds.pos)}") # self.tic.record() - with Timer(" frustum pytorch", verbose=True): - with Timer(" frustum 11"): - gaussian_3ds_pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran) - with Timer(" frustum 12"): - valid = gaussian_3ds_pos_camera_space[:,2] > self.near - gaussian_3ds_pos_image_space = 
-                culling_mask = (gaussian_3ds_pos_image_space[:, 0].abs() < (self.current_camera.width/2/self.current_camera.params[0])) & \
-                               (gaussian_3ds_pos_image_space[:, 1].abs() < (self.current_camera.height/2/self.current_camera.params[1]))
-            with Timer(" frustum 13"):
-                self.gaussian_3ds_valid = self.gaussian_3ds.filte(valid)#.filte(culling_mask)
-        with Timer(" frustum 2"):
-            self.culling_gaussian_3d_image_space = self.gaussian_3ds_valid.project(self.current_w2c_rot, self.current_w2c_tran, self.near, self.jacobian_calc).filte(culling_mask)
-        # self.culling_gaussian_3d_image_space = self.gaussian_3ds.project(self.current_w2c_rot, self.current_w2c_tran, self.near, self.jacobian_calc).filte(valid).filte(culling_mask)
-        # culling
-        # with Timer(" frumstum 3"):
-        #     self.culling_gaussian_3d_image_space = self.gaussian_3ds_image_space.filte(culling_mask)
-        # with Timer("step11"):
-        #     pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran)
-        # with Timer("step12"):
-        #     valid = pos_camera_space[:,2] > self.near
-        # with Timer("step13"):
-        #     pos_img_space = camera_to_image(pos_camera_space)
-        # with Timer("step14"):
-        #     culling_mask = (pos_img_space[:, 0].abs() < (self.current_camera.width/2/self.current_camera.params[0])) & \
-        #                    (pos_img_space[:, 1].abs() < (self.current_camera.height/2/self.current_camera.params[1]))
-        # with Timer("step15"):
-        #     # calc 2d cov
-        #     valid = valid & culling_mask
-        #     print(valid.sum(0))
-
-        # with Timer("step2"):
-        #     jacobian = torch.empty(valid.sum(0), 3, 3, device=self.device)
-        #     gaussian.jacobian(pos_camera_space[valid], jacobian)
-
-        #     R = q2r(self.gaussian_3ds.quat[valid])
-        #     _scale = (self.gaussian_3ds.scale[valid]).abs()+1e-4
-        #     S = torch.diag_embed(_scale)
-        #     RS = torch.bmm(R, S)
-        #     RSSR = torch.bmm(RS, RS.permute(0,2,1))
-
-        #     JW = torch.einsum("bij,bjk->bik", jacobian, self.current_w2c_rot.unsqueeze(dim=0))
-        #     cov_2d = torch.bmm(torch.bmm(JW, RSSR), JW.permute(0,2,1))[:, :2, :2]
-
-        # with Timer("step3"):
-        #     self.culling_gaussian_3d_image_space = Gaussian3ds(
-        #         pos=pos_img_space[valid],
-        #         rgb=self.gaussian_3ds.rgb.sigmoid()[valid],
-        #         opa=self.gaussian_3ds.opa.sigmoid()[valid],
-        #         cov=cov_2d,
-        #     )
+        if self.cudaculling:
+            with Timer(" frustum cuda"):
+                # self.gaussian_3ds.rgb = self.gaussian_3ds.rgb.sigmoid()
+                # self.gaussian_3ds.opa = self.gaussian_3ds.opa.sigmoid()
+                # self.gaussian_3ds.quat = self.gaussian_3ds.quat / self.gaussian_3ds.quat.norm(dim=-1, keepdim=True)
+                # self.gaussian_3ds.scale = self.gaussian_3ds.scale.abs()+1e-5
+                normed_quat = (self.gaussian_3ds.quat/self.gaussian_3ds.quat.norm(dim=-1, keepdim=True))
+                if self.scale_activation == "abs":
+                    normed_scale = self.gaussian_3ds.scale.abs()+1e-6
+                else:
+                    assert self.scale_activation == "exp"
+                    normed_scale = trunc_exp(self.gaussian_3ds.scale)
+                _pos, _cov, _culling_mask = global_culling(
+                    self.gaussian_3ds.pos,
+                    normed_quat,
+                    normed_scale,
+                    self.current_w2c_rot.detach(),
+                    self.current_w2c_tran.detach(),
+                    self.near,
+                    self.current_camera.width/2/self.current_camera.params[0],
+                    self.current_camera.height/2/self.current_camera.params[1],
+                )
+
+            self.culling_gaussian_3d_image_space = Gaussian3ds(
+                pos=_pos[_culling_mask.bool()],
+                cov=_cov[_culling_mask.bool()],
+                rgb=self.gaussian_3ds.rgb[_culling_mask.bool()].sigmoid(),
+                opa=self.gaussian_3ds.opa[_culling_mask.bool()].sigmoid(),
+            )
+            # print(len(self.culling_gaussian_3d_image_space.pos))
+        else:
+            gaussian_3ds_pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran)
+            valid = gaussian_3ds_pos_camera_space[:,2] > self.near
+            gaussian_3ds_pos_image_space = camera_to_image(gaussian_3ds_pos_camera_space)
+            culling_mask = (gaussian_3ds_pos_image_space[:, 0].abs() < (self.current_camera.width*1.2/2/self.current_camera.params[0])) & \
+                           (gaussian_3ds_pos_image_space[:, 1].abs() < (self.current_camera.height*1.2/2/self.current_camera.params[1]))
+            valid &= culling_mask
+            self.gaussian_3ds_valid = self.gaussian_3ds.filte(valid)
+            self.culling_gaussian_3d_image_space = self.gaussian_3ds_valid.project(
+                self.current_w2c_rot,
+                self.current_w2c_tran,
+                self.near,
+                self.jacobian_calc,
+                scale_activation=self.scale_activation,
+            )
 
     def render(self, out_write=True):
         # self.tic.record()
         with Timer(" culling tiles", debug=self.debug):
             tile_n_point = torch.zeros(len(self.tile_info), device=self.device, dtype=torch.int32)
+            # print("***{}***".format(len(self.culling_gaussian_3d_image_space.pos)//10))
             tile_gaussian_list = torch.ones(len(self.tile_info), len(self.culling_gaussian_3d_image_space.pos)//10, device=self.device, dtype=torch.int32) * -1
             _method_config = {"dist": 0, "prob": 1, "prob2": 2}
             gaussian.calc_tile_list(
@@ -450,6 +487,7 @@ def render(self, out_write=True):
                     self.tile_info.topmost,
                     # 0.01,
             )
+            # print(tile_gaussian_list.sum())
 
         with Timer(" gather culled tiles", debug=self.debug):
             gathered_list = torch.empty(tile_n_point.sum(), dtype=torch.int32, device=self.device)
@@ -498,14 +536,14 @@ def render(self, out_write=True):
 
         return rendered_image
 
-    def forward(self, camera_id=None, record_view_space_pos_grad=False, cudaculling=False):
+    def forward(self, camera_id=None, record_view_space_pos_grad=False):
        # print(self.gaussian_3ds.opa.max())
        # print(self.gaussian_3ds.opa.min())
        with Timer("forward", debug=self.debug):
            with Timer("set camera"):
                self.set_camera(camera_id)
            with Timer("frustum culling", debug=self.debug):
-                self.project_and_culling(cudaculling)
+                self.project_and_culling()
            with Timer("render function", debug=self.debug):
                padded_render_img = self.render(out_write=False)
            with Timer("crop", debug=self.debug):
diff --git a/src/gaussian.cu b/src/gaussian.cu
index e0b07c1..96aecd8 100644
--- a/src/gaussian.cu
+++ b/src/gaussian.cu
@@ -359,7 +359,6 @@ void gather_gaussians(
        max_points_for_tile,
        gaussian_list_size
    );
-
 }
 
 template
@@ -470,7 +469,7 @@ __global__ void draw_backward_kernel(
            global_idx = i_loadings*SMSIZE + i;
 
            // check bound and early stop
-            if(global_idx>=(end_idx-start_idx)||accum < 0.01){
+            if(global_idx>=(end_idx-start_idx)||accum < 0.0001){
                break;
            }
            _a = _gaussian_cov[i*4+0];
@@ -666,7 +665,7 @@ __global__ void draw_kernel(
            global_idx = i_loadings*SMSIZE + i;
 
            // check bound and early stop
-            if(global_idx>=(end_idx-start_idx)||accum < 0.01){
+            if(global_idx>=(end_idx-start_idx)||accum < 0.0001){
                break;
            }
            _a = _gaussian_cov[i*4+0];
diff --git a/test.py b/test.py
index 290828d..1ad8681 100644
--- a/test.py
+++ b/test.py
@@ -1,26 +1,10 @@
 import torch
-import fmatmul as mul
-import time
+from utils import Timer
+import gaussian
 
-K = 4096
-scale = 1
-version = 4
-dtype=torch.float
-a = scale * torch.randn(K, K, dtype=dtype, device='cuda')
-b = scale * torch.randn(K, K, dtype=dtype, device='cuda')
-c = torch.zeros(K, K, dtype=dtype, device='cuda')
-tic = torch.cuda.Event(enable_timing=True)
-toc = torch.cuda.Event(enable_timing=True)
-_rep = 10
-mul.matmul(a, b, c, version)
-tic.record()
-for i in range(_rep):
-    mul.matmul(a, b, c, version)
-toc.record()
-torch.cuda.synchronize()
-elapse = tic.elapsed_time(toc) / 1000
-print(elapse)
-flops = (2*K**3 * 1e-9 * _rep)/elapse
-print(flops)
-print(((c-a@b).abs()>1e-2).sum())
\ No newline at end of file
+a = torch.randn(130000, 3).cuda()
+b = torch.randn(1,3,3).cuda()
+for i in range(10):
+    with Timer("@", verbose=True):
+        c = gaussian.world2camera(a,b)
diff --git a/train.py b/train.py
index 4ff3dca..eddd63c 100644
--- a/train.py
+++ b/train.py
@@ -9,6 +9,19 @@ from torchmetrics.functional import peak_signal_noise_ratio as psnr_func
 from utils import Timer
 
+@torch.no_grad()
+def test(gaussian_splatter, camera_id):
+    ssim_criterion = SSIM(window_size=11, reduction='mean')
+    gaussian_splatter.eval()
+    rendered_img = gaussian_splatter(camera_id)
+    psnr = psnr_func(rendered_img, gaussian_splatter.ground_truth)
+    ssim = ssim_criterion(
+        rendered_img.unsqueeze(0).permute(0, 3, 1, 2),
+        gaussian_splatter.ground_truth.unsqueeze(0).permute(0, 3, 1, 2).to(rendered_img),
+    )
+    gaussian_splatter.train()
+    return psnr, ssim
+
 def train(gaussian_splatter, opt):
     # lr = 0.05
     lr_opa = opt.lr * opt.lr_factor_for_opa
@@ -17,9 +30,10 @@ def train(gaussian_splatter, opt):
     lr_quat = opt.lr * opt.lr_factor_for_quat
     lr_scale = opt.lr * opt.lr_factor_for_scale
     lrs = [lr_opa, lr_rgb, lr_pos, lr_scale, lr_quat]
-
-    warmup_iters = 100
-    lr_lambda = lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 1 * 0.2**((i_iter-warmup_iters) // 2000)
+
+    warmup_iters = opt.n_iters_warmup
+    lr_lambda = lambda i_iter: i_iter / warmup_iters if i_iter <= warmup_iters else 0.2**((i_iter-warmup_iters) // 2000)
+    # lr_lambda = lambda i_iter: 0.01**(i_iter/opt.n_iters)
     optimizer = torch.optim.Adam([
         {"params": gaussian_splatter.gaussian_3ds.opa, "lr": lr_opa * lr_lambda(0)},
         {"params": gaussian_splatter.gaussian_3ds.rgb, "lr": lr_rgb * lr_lambda(0)},
@@ -45,9 +59,17 @@ def hook(grad):
             grads[name] = grad
         return hook
 
+    test_split = np.arange(0, n_cameras, 8)
+    train_split = np.array(list(set(np.arange(0, n_cameras, 1)) - set(test_split)))
+
     for i_iter in bar:
+        _adaptive_control = i_iter >= 800 and i_iter % opt.n_adaptive_control == 0
         optimizer.zero_grad()
-        camera_id = np.random.randint(0, n_cameras)
+
+        # train split
+        camera_id = np.random.choice(train_split, 1)[0]
+        # camera_id = np.random.randint(0, n_cameras)
+
         #rendered_img = gaussian_splatter(camera_id, record_view_space_pos_grad=(i_iter % opt.n_adaptive_control == 0 and i_iter>0))
         rendered_img = gaussian_splatter(camera_id)
         # loss = ((rendered_img - gaussian_splatter.ground_truth)**2).mean()
@@ -55,17 +77,15 @@ def hook(grad):
         ssim_loss = ssim_criterion(rendered_img.unsqueeze(0).permute(0, 3, 1, 2), gaussian_splatter.ground_truth.unsqueeze(0).permute(0, 3, 1, 2).to(rendered_img))
         # ssim_loss = 0
         #loss = 0.99*l1_loss + 0.01*ssim_loss
-        loss = 0.8*l1_loss + 0.2*ssim_loss
+        loss = (1-opt.ssim_weight)*l1_loss + opt.ssim_weight*ssim_loss
         if opt.scale_reg > 0:
             loss += opt.scale_reg * gaussian_splatter.gaussian_3ds.scale.abs().mean()
         handle = None
-        if i_iter % opt.n_adaptive_control == 0 and i_iter > 0:
-            # gaussian_splatter.gaussian_3ds_pos_camera_space_culled.register_hook(save_grad("grad"))
-            handle = gaussian_splatter.gaussian_3ds.pos.register_hook(save_grad("grad"))
-        else:
-            if handle is not None:
-                handle.remove()
-        #gaussian_splatter.gaussian_3ds.pos.register_hook(save_grad("grad"))
+        # if i_iter % opt.n_adaptive_control == 0 and i_iter > 0:
+        #     handle = gaussian_splatter.gaussian_3ds.pos.register_hook(save_grad("grad"))
+        # else:
+        #     if handle is not None:
+        #         handle.remove()
 
         # ssim_loss = 0
         with Timer("psnr"):
@@ -85,59 +105,97 @@ def hook(grad):
             avg_l1_loss = l1_losses[:min(i_iter+1, l1_losses.shape[0])].mean()
             avg_ssim_loss = ssim_losses[:min(i_iter+1, ssim_losses.shape[0])].mean()
             avg_psnr = psnrs[:min(i_iter+1, psnrs.shape[0])].mean()
-            bar.set_description(desc=f"loss: {avg_l1_loss:.6f}/{avg_ssim_loss:.6f}/{avg_psnr:.4f}/[{gaussian_splatter.n_tile_gaussians}/{gaussian_splatter.n_gaussians}]: lr: {optimizer.param_groups[0]['lr']:.6f}")
+
+            grad_info = [
+                gaussian_splatter.gaussian_3ds.opa.grad.abs().mean(),
+                gaussian_splatter.gaussian_3ds.rgb.grad.abs().mean(),
+                gaussian_splatter.gaussian_3ds.pos.grad.abs().mean(),
+                gaussian_splatter.gaussian_3ds.scale.grad.abs().mean(),
+                gaussian_splatter.gaussian_3ds.quat.grad.abs().mean(),
+            ]
+            grad_desc = "[{:.6f}|{:.6f}|{:.6f}|{:.6f}|{:.6f}]".format(*grad_info)
+            bar.set_description(desc=f"loss: {avg_l1_loss:.6f}/{avg_ssim_loss:.6f}/{avg_psnr:.4f}/[{gaussian_splatter.n_tile_gaussians}/{gaussian_splatter.n_gaussians}]:" +
+                f"lr: {optimizer.param_groups[0]['lr']:.6f} " +
+                f"grad: {grad_desc}"
+            )
         if i_iter % opt.n_save_train_img == 0:
             img_npy = rendered_img.clip(0,1).detach().cpu().numpy()
             cv2.imwrite(f"imgs/train_{i_iter}.png", (img_npy*255).astype(np.uint8)[...,::-1])
-        # print(gaussian_splatter.gaussian_3ds.opa.grad.abs().mean())
-        # print(gaussian_splatter.gaussian_3ds.rgb.grad.abs().mean())
-        # print(gaussian_splatter.gaussian_3ds.pos.grad.abs().mean())
-        # print(gaussian_splatter.gaussian_3ds.quat.grad.abs().mean())
-        # print(gaussian_splatter.gaussian_3ds.scale.grad.abs().mean())
+        # if i_iter == 500:
+        #     print("="*100)
+        #     print(gaussian_splatter.gaussian_3ds.opa.grad.abs().mean())
+        #     print(gaussian_splatter.gaussian_3ds.rgb.grad.abs().mean())
+        #     print(gaussian_splatter.gaussian_3ds.pos.grad.abs().mean())
+        #     print(gaussian_splatter.gaussian_3ds.quat.grad.abs().mean())
+        #     print(gaussian_splatter.gaussian_3ds.scale.grad.abs().mean())
+        #     exit()
         if i_iter % 100 == 0:
             Timer.show_recorder()
 
-        if i_iter % opt.n_adaptive_control == 0 and i_iter > 0:
+        if _adaptive_control:
             # adaptive control for gaussians
-            grad = grads["grad"]
+            grad = gaussian_splatter.gaussian_3ds.pos.grad
             # print(grad)
             adaptive_number = (grad.abs().max(-1)[0] > 0.0002).sum()
             adaptive_ratio = adaptive_number / grad[..., 0].numel()
             # print(adaptive_number, adaptive_ratio)
-            gaussian_splatter.gaussian_3ds.adaptive_control(grad, taus=opt.split_thresh, delete_thresh=opt.delete_thresh)
+            gaussian_splatter.gaussian_3ds.adaptive_control(grad, taus=opt.split_thresh, delete_thresh=opt.delete_thresh, scale_activation=gaussian_splatter.scale_activation)
             optimizer = torch.optim.Adam(gaussian_splatter.parameters(), lr=lr_lambda(0), betas=(0.9, 0.99))
-        # if i_iter % (opt.n_adaptive_control*5) == 0 and i_iter > 0:
-        #     gaussian_splatter.gaussian_3ds.reset_opa()
+        if i_iter % (opt.n_opa_reset) == 0 and i_iter > 0:
+            gaussian_splatter.gaussian_3ds.reset_opa()
+
+        for i_opt, (param_group, lr) in enumerate(zip(optimizer.param_groups, lrs)):
+            if opt.adaptive_lr:
+                param_group['lr'] = min(lr_lambda(i_iter) * lrs[2] * grad_info[2]/grad_info[i_opt], 0.1)
+            else:
+                param_group['lr'] = lr_lambda(i_iter) * lr
-        for param_group, lr in zip(optimizer.param_groups, lrs):
-            param_group['lr'] = lr_lambda(i_iter) * lr
+        if i_iter % (opt.n_iters_test) == 0:
+            test_psnrs = []
+            test_ssims = []
+            for test_camera_id in test_split:
+                psnr, ssim = test(gaussian_splatter, test_camera_id)
+                test_psnrs.append(psnr.item())
+                test_ssims.append(ssim.item())
+            print(test_psnrs)
+            print(test_ssims)
+            print("TEST SPLIT PSNR: {:.4f}".format(np.mean(test_psnrs)))
+            print("TEST SPLIT SSIM: {:.4f}".format(np.mean(test_ssims)))
+
 
 if __name__ == "__main__":
     # python train.py --render_downsample 2 --scale_init_value 0.01 --opa_init_value 0.5 --lr 0.001
     parser = argparse.ArgumentParser()
     parser.add_argument("--n_iters", type=int, default=10000)
+    parser.add_argument("--n_iters_warmup", type=int, default=200)
+    parser.add_argument("--n_iters_test", type=int, default=200)
     parser.add_argument("--n_history_track", type=int, default=100)
     parser.add_argument("--n_save_train_img", type=int, default=100)
-    parser.add_argument("--n_adaptive_control", type=int, default=40000)
-    parser.add_argument("--render_downsample", type=int, default=2)
+    parser.add_argument("--n_adaptive_control", type=int, default=200)
+    parser.add_argument("--render_downsample", type=int, default=4)
     parser.add_argument("--jacobian_track", type=int, default=0)
     parser.add_argument("--data", type=str, default="garden")
-    parser.add_argument("--scale_init_value", type=float, default=0.02)
-    parser.add_argument("--opa_init_value", type=float, default=0.1)
+    parser.add_argument("--scale_init_value", type=float, default=1)
+    parser.add_argument("--opa_init_value", type=float, default=0.6)
     parser.add_argument("--tile_culling_dist_thresh", type=float, default=0.5)
-    parser.add_argument("--tile_culling_prob_thresh", type=float, default=0.1)
-    parser.add_argument("--tile_culling_method", type=str, default="dist", choices=["dist", "prob", "prob2"])
-    parser.add_argument("--lr", type=float, default=0.01)
+    parser.add_argument("--tile_culling_prob_thresh", type=float, default=0.05)
+    parser.add_argument("--tile_culling_method", type=str, default="prob2", choices=["dist", "prob", "prob2"])
+    parser.add_argument("--lr", type=float, default=0.001)
     parser.add_argument("--lr_factor_for_scale", type=float, default=1)
     parser.add_argument("--lr_factor_for_opa", type=float, default=1)
     parser.add_argument("--lr_factor_for_quat", type=float, default=1)
     parser.add_argument("--delete_thresh", type=float, default=1.5)
+    parser.add_argument("--n_opa_reset", type=int, default=10000000)
     parser.add_argument("--split_thresh", type=float, default=0.05)
+    parser.add_argument("--ssim_weight", type=float, default=0.2)
     parser.add_argument("--debug", type=int, default=1)
     parser.add_argument("--scale_reg", type=float, default=0)
+    parser.add_argument("--cudaculling", type=int, default=0)
+    parser.add_argument("--adaptive_lr", type=int, default=0)
     parser.add_argument("--seed", type=int, default=2023)
+    parser.add_argument("--scale_activation", type=str, default="abs", choices=["abs", "exp"])
     opt = parser.parse_args()
     np.random.seed(opt.seed)
     if opt.jacobian_track:
@@ -162,6 +220,8 @@ def hook(grad):
         tile_culling_dist_thresh=opt.tile_culling_dist_thresh,
         tile_culling_prob_thresh=opt.tile_culling_prob_thresh,
         debug=opt.debug,
+        scale_activation=opt.scale_activation,
+        cudaculling=opt.cudaculling,
         #jacobian_calc="torch",
     )
     train(gaussian_splatter, opt)
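Note on the per-point scale initialization added to Splatter.__init__ in splatter.py: the tqdm loop computes, for each COLMAP point, the mean distance to its three nearest neighbours, multiplies it by scale_init_value, and takes the log when scale_activation="exp". The same quantity can also be computed in batched form with torch.cdist; the sketch below is only an illustration under that assumption (the helper name and chunk size are not part of this patch), not a drop-in change to the diff.

import torch

def mean_knn_distance(pos, k=3, chunk=4096):
    # pos: B x 3 point positions; returns a length-B tensor holding the mean
    # distance from each point to its k nearest neighbours (self excluded),
    # matching the sort(dim=-1)[0][1:k+1].mean() computed in the loop above.
    out = []
    for start in range(0, pos.shape[0], chunk):
        block = pos[start:start + chunk]                       # chunk x 3
        dists = torch.cdist(block, pos)                        # chunk x B pairwise distances
        # the k+1 smallest distances include the zero self-distance; drop it
        knn = dists.topk(k + 1, dim=-1, largest=False).values[:, 1:]
        out.append(knn.mean(dim=-1))
    return torch.cat(out)

# e.g. mean_min_three_dis = mean_knn_distance(_pos) * scale_init_value

Chunking keeps the pairwise distance matrix at chunk x B instead of B x B, which matters once the COLMAP point cloud reaches hundreds of thousands of points.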