Skip to content

Commit

Permalink
sync
Browse files (browse the repository at this point in the history)
  • Loading branch information
WangFeng18 committed Jun 16, 2023
1 parent 32f6786 commit 1abe284
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 146 deletions.
Binary file added .splatter.py.swp
Binary file not shown.
2 changes: 1 addition & 1 deletion renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def forward(ctx, x):
@staticmethod
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(-3, 1))
return g * torch.exp(x.clamp(-1, 1))

trunc_exp = _trunc_exp.apply

Expand Down
212 changes: 125 additions & 87 deletions splatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def to_scale_matrix(self):

def _tocpp(self):
_cobj = gaussian.Gaussian3ds()
_cobj.pos = self.pos
_cobj.rgb = self.rgb
_cobj.opa = self.opa
_cobj.cov = self.cov
_cobj.pos = self.pos.clone()
_cobj.rgb = self.rgb.clone()
_cobj.opa = self.opa.clone()
_cobj.cov = self.cov.clone()
return _cobj

def to(self, *args, **kwargs):
Expand Down Expand Up @@ -93,12 +93,18 @@ def filte(self, mask):
cov=self.cov[mask],
)

def get_gaussian_3d_conv(self):
def get_gaussian_3d_cov(self, scale_activation="abs"):
R = q2r(self.quat)
_scale = self.scale.abs()+1e-4
if scale_activation == "abs":
_scale = self.scale.abs()+1e-4
elif scale_activation == "exp":
_scale = trunc_exp(self.scale)
else:
print("No support scale activation")
exit()
# _scale = trunc_exp(self.scale)
# _scale = torch.clamp(_scale, min=1e-4, max=0.1)
S = torch.diag_embed(_scale)
# S = torch.diag_embed(trunc_exp(self.scale))
RS = torch.bmm(R, S)
RSSR = torch.bmm(RS, RS.permute(0,2,1))
return RSSR
Expand All @@ -109,15 +115,22 @@ def get_gaussian_3d_conv(self):
def reset_opa(self):
torch.nn.init.uniform_(self.opa, a=inverse_sigmoid(0.1), b=inverse_sigmoid(0.11))

def adaptive_control(self, grad, taus, delete_thresh):
def adaptive_control(self, grad, taus, delete_thresh, scale_activation="abs"):
# grad: B x 3
# densification
# 1. delete gaussians with small opacities
assert self.init_values # only for the initial gaussians
print(inverse_sigmoid(0.1))
print(inverse_sigmoid(0.01))
print(self.opa.min())
print(self.opa.max())
_mask = (self.opa > inverse_sigmoid(0.1)) & (self.scale.norm(dim=-1) < delete_thresh)
if scale_activation == "abs":
_mask = (self.opa > inverse_sigmoid(0.01)) & (self.scale.norm(dim=-1) < delete_thresh)
elif scale_activation == "exp":
_mask = (self.opa > inverse_sigmoid(0.01)) & (self.scale.exp().norm(dim=-1) < delete_thresh)
else:
print("Wrong activation")
exit()

self.pos = nn.parameter.Parameter(self.pos.detach()[_mask])
self.rgb = nn.parameter.Parameter(self.rgb.detach()[_mask])
self.opa = nn.parameter.Parameter(self.opa.detach()[_mask])
Expand All @@ -127,21 +140,21 @@ def adaptive_control(self, grad, taus, delete_thresh):
print("DELETE: {} Gaussians".format((~_mask).sum()))
# 2. clone or split
densify_mask = grad.abs().max(-1)[0] > 0.0002
cat_pos = [self.pos.detach()]
cat_rgb = [self.rgb.detach()]
cat_opa = [self.opa.detach()]
cat_quat = [self.quat.detach()]
cat_scale = [self.scale.detach()]
cat_pos = [self.pos.clone().detach()]
cat_rgb = [self.rgb.clone().detach()]
cat_opa = [self.opa.clone().detach()]
cat_quat = [self.quat.clone().detach()]
cat_scale = [self.scale.clone().detach()]
if densify_mask.any():
scale_norm = self.scale.norm(dim=-1)
scale_norm = self.scale.norm(dim=-1) if scale_activation == "abs" else self.scale.exp().norm(dim=-1)
split_mask = scale_norm > taus
clone_mask = scale_norm <= taus
split_mask = split_mask & densify_mask
clone_mask = clone_mask & densify_mask
if clone_mask.any():
cloned_pos = self.pos[clone_mask].clone().detach()
# cloned_pos -= grad[clone_mask] * 0.01
cloned_pos -= grad[clone_mask] * 10
cloned_pos -= grad[clone_mask] * 0.01
cloned_rgb = self.rgb[clone_mask].clone().detach()
cloned_opa = self.opa[clone_mask].clone().detach()
cloned_quat = self.quat[clone_mask].clone().detach()
Expand All @@ -154,24 +167,33 @@ def adaptive_control(self, grad, taus, delete_thresh):
cat_scale.append(cloned_scale)

if split_mask.any():
_scale = self.scale.detach()
_scale[split_mask] /= 1.6

cat_scale[0][split_mask] /= 1.6
self.scale = nn.parameter.Parameter(_scale)
_scale = self.scale.clone().detach()
if scale_activation == "abs":
_scale[split_mask] /= 1.6
elif scale_activation == "exp":
_scale[split_mask] -= math.log(1.6)
else:
print("Wrong activation")
exit()

cat_scale[0] = _scale
# cat_scale[0][split_mask] /= 1.6
# self.scale = nn.parameter.Parameter(_scale)

# sampling two positions
this_cov = self.get_gaussian_3d_conv()[split_mask]
this_cov = self.get_gaussian_3d_cov(scale_activation=scale_activation)[split_mask]
p1, p2 = sample_two_point(self.pos[split_mask], this_cov)

# split_pos = self.pos[split_mask].clone().detach()
# split_pos -= grad[split_mask] * 0.01
cat_pos[0][split_mask] = p1.detach()
origin_pos = cat_pos[0]
origin_pos[split_mask] = p1.detach()
cat_pos[0] = origin_pos
split_pos = p2.detach()
split_rgb = self.rgb[split_mask].clone().detach()
split_opa = self.opa[split_mask].clone().detach()
split_quat = self.quat[split_mask].clone().detach()
split_scale = self.scale[split_mask].clone().detach()
split_scale = _scale[split_mask].clone()
print("SPLIT : {} Gaussians".format(split_pos.shape[0]))
cat_pos.append(split_pos)
cat_rgb.append(split_rgb)
Expand All @@ -184,7 +206,7 @@ def adaptive_control(self, grad, taus, delete_thresh):
self.quat = nn.parameter.Parameter(torch.cat(cat_quat))
self.scale = nn.parameter.Parameter(torch.cat(cat_scale))

def project(self, rot, tran, near, jacobian_calc):
def project(self, rot, tran, near, jacobian_calc, scale_activation="abs"):

with Timer(" w2c"):
pos_cam_space = world_to_camera(self.pos, rot, tran)
Expand All @@ -198,9 +220,12 @@ def project(self, rot, tran, near, jacobian_calc):
else:
jacobian = jacobian_torch(pos_cam_space)

with Timer(" cov"):
gaussian_3d_cov = self.get_gaussian_3d_conv()
JW = torch.einsum("bij,bjk->bik", jacobian, rot.unsqueeze(dim=0))
with Timer(" cov1"):
gaussian_3d_cov = self.get_gaussian_3d_cov(scale_activation=scale_activation)
# JW = torch.einsum("bij,bjk->bik", jacobian, rot.unsqueeze(dim=0))
with Timer(" cov2"):
JW = torch.matmul(jacobian, rot.unsqueeze(dim=0))
with Timer(" cov3"):
gaussian_2d_cov = torch.bmm(torch.bmm(JW, gaussian_3d_cov), JW.permute(0,2,1))[:, :2, :2]

with Timer(" last"):
Expand Down Expand Up @@ -229,8 +254,8 @@ def crop(self, image):
# output: height x width x 3
top = int(self.padded_height - self.height)//2
left = int(self.padded_width - self.width)//2
# return image[top:top+int(self.height), left:left+int(self.width), :]
return image[top:top+int(self.height)-1, left:left+int(self.width)-1, :]
#return image[top:top+int(self.height), left:left+int(self.width), :]
return image[top:top+int(self.height)-1, left:left+int(self.width), :]

def create_tiles(self):
self.tiles_left = torch.linspace(-self.padded_width/2, self.padded_width/2, self.n_tile_x + 1, device=self.device)[:-1]
Expand Down Expand Up @@ -263,7 +288,6 @@ def create_tiles(self):
def __len__(self):
return self.tiles_top.shape[0]


class Splatter(nn.Module):
def __init__(self,
colmap_path,
Expand All @@ -280,6 +304,8 @@ def __init__(self,
tile_culling_dist_thresh=0.5,
tile_culling_prob_thresh=0.1,
debug=1,
scale_activation="abs",
cudaculling=0,
):
super().__init__()
self.device = torch.device("cuda")
Expand All @@ -292,6 +318,8 @@ def __init__(self,
self.tile_culling_dist_thresh = tile_culling_dist_thresh
self.tile_culling_prob_thresh = tile_culling_prob_thresh
self.debug = debug
self.scale_activation = scale_activation
self.cudaculling = cudaculling
assert jacobian_calc == "cuda" or jacobian_calc == "torch"

self.points3d = read_points3d_binary(os.path.join(colmap_path, "points3D.bin"))
Expand All @@ -313,13 +341,26 @@ def __init__(self,
rgb = torch.stack(_rgbs).to(torch.float32).to(self.device) # B x 3
if self.use_sh_coeff:
rgb = initialize_sh(rgb)

_pos=torch.stack(_points).to(torch.float32).to(self.device)
mean_min_three_dis = []
for i_pos in tqdm(range(_pos.shape[0])):
_r = (_pos[i_pos:i_pos+1] - _pos).norm(dim=-1).sort(dim=-1)[0][1:4].mean().item()
mean_min_three_dis.append(_r)
mean_min_three_dis = torch.Tensor(mean_min_three_dis).to(torch.float32) * scale_init_value

if scale_activation == "exp":
mean_min_three_dis = mean_min_three_dis.log()

self.gaussian_3ds = Gaussian3ds(
pos=torch.stack(_points).to(torch.float32).to(self.device), # B x 3
pos=_pos.to(self.device), # B x 3
rgb = rgb, # B x 3 or 27
opa = torch.ones(len(_points)).to(torch.float32).to(self.device)*inverse_sigmoid(opa_init_value), # B
quat = torch.Tensor([1, 0, 0, 0]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4
# quat = torch.Tensor([1, 1, 2, 1]).unsqueeze(dim=0).repeat(len(_points),1).to(torch.float32).to(self.device), # B x 4
scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*scale_init_value,
#scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*math.log(scale_init_value),
#scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*scale_init_value,
scale = torch.ones(len(_points), 3).to(torch.float32).to(self.device)*mean_min_three_dis.unsqueeze(dim=1).to(self.device),
init_values=True,
)
self.current_camera = None
Expand Down Expand Up @@ -372,65 +413,61 @@ def set_camera(self, idx):
# print("current_camera info")
# print(self.current_camera)

def project_and_culling(self, cudaculling=False):
def project_and_culling(self):
# project 3D to 2D
# print(f"number of gaussians {len(self.gaussian_3ds.pos)}")
# self.tic.record()
with Timer(" frustum pytorch", verbose=True):
with Timer(" frustum 11"):
gaussian_3ds_pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran)
with Timer(" frustum 12"):
valid = gaussian_3ds_pos_camera_space[:,2] > self.near
gaussian_3ds_pos_image_space = camera_to_image(gaussian_3ds_pos_camera_space[valid])
culling_mask = (gaussian_3ds_pos_image_space[:, 0].abs() < (self.current_camera.width/2/self.current_camera.params[0])) & \
(gaussian_3ds_pos_image_space[:, 1].abs() < (self.current_camera.height/2/self.current_camera.params[1]))
with Timer(" frustum 13"):
self.gaussian_3ds_valid = self.gaussian_3ds.filte(valid)#.filte(culling_mask)
with Timer(" frustum 2"):
self.culling_gaussian_3d_image_space = self.gaussian_3ds_valid.project(self.current_w2c_rot, self.current_w2c_tran, self.near, self.jacobian_calc).filte(culling_mask)
# self.culling_gaussian_3d_image_space = self.gaussian_3ds.project(self.current_w2c_rot, self.current_w2c_tran, self.near, self.jacobian_calc).filte(valid).filte(culling_mask)
# culling
# with Timer(" frumstum 3"):
# self.culling_gaussian_3d_image_space = self.gaussian_3ds_image_space.filte(culling_mask)
# with Timer("step11"):
# pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran)
# with Timer("step12"):
# valid = pos_camera_space[:,2] > self.near
# with Timer("step13"):
# pos_img_space = camera_to_image(pos_camera_space)
# with Timer("step14"):
# culling_mask = (pos_img_space[:, 0].abs() < (self.current_camera.width/2/self.current_camera.params[0])) & \
# (pos_img_space[:, 1].abs() < (self.current_camera.height/2/self.current_camera.params[1]))
# with Timer("step15"):
# # calc 2d cov
# valid = valid & culling_mask
# print(valid.sum(0))

# with Timer("step2"):
# jacobian = torch.empty(valid.sum(0), 3, 3, device=self.device)
# gaussian.jacobian(pos_camera_space[valid], jacobian)

# R = q2r(self.gaussian_3ds.quat[valid])
# _scale = (self.gaussian_3ds.scale[valid]).abs()+1e-4
# S = torch.diag_embed(_scale)
# RS = torch.bmm(R, S)
# RSSR = torch.bmm(RS, RS.permute(0,2,1))

# JW = torch.einsum("bij,bjk->bik", jacobian, self.current_w2c_rot.unsqueeze(dim=0))
# cov_2d = torch.bmm(torch.bmm(JW, RSSR), JW.permute(0,2,1))[:, :2, :2]

# with Timer("step3"):
# self.culling_gaussian_3d_image_space = Gaussian3ds(
# pos=pos_img_space[valid],
# rgb=self.gaussian_3ds.rgb.sigmoid()[valid],
# opa=self.gaussian_3ds.opa.sigmoid()[valid],
# cov=cov_2d,
# )
if self.cudaculling:
with Timer(" frustum cuda"):
# self.gaussian_3ds.rgb = self.gaussian_3ds.rgb.sigmoid()
# self.gaussian_3ds.opa = self.gaussian_3ds.opa.sigmoid()
# self.gaussian_3ds.quat = self.gaussian_3ds.quat / self.gaussian_3ds.quat.norm(dim=-1, keepdim=True)
# self.gaussian_3ds.scale = self.gaussian_3ds.scale.abs()+1e-5
normed_quat = (self.gaussian_3ds.quat/self.gaussian_3ds.quat.norm(dim=-1, keepdim=True))
if self.scale_activation == "abs":
normed_scale = self.gaussian_3ds.scale.abs()+1e-6
else:
assert self.scale_activation == "exp"
normed_scale = trunc_exp(self.gaussian_3ds.scale)
_pos, _cov, _culling_mask = global_culling(
self.gaussian_3ds.pos,
normed_quat,
normed_scale,
self.current_w2c_rot.detach(),
self.current_w2c_tran.detach(),
self.near,
self.current_camera.width/2/self.current_camera.params[0],
self.current_camera.height/2/self.current_camera.params[1],
)

self.culling_gaussian_3d_image_space = Gaussian3ds(
pos=_pos[_culling_mask.bool()],
cov=_cov[_culling_mask.bool()],
rgb=self.gaussian_3ds.rgb[_culling_mask.bool()].sigmoid(),
opa=self.gaussian_3ds.opa[_culling_mask.bool()].sigmoid(),
)
# print(len(self.culling_gaussian_3d_image_space.pos))
else:
gaussian_3ds_pos_camera_space = world_to_camera(self.gaussian_3ds.pos, self.current_w2c_rot, self.current_w2c_tran)
valid = gaussian_3ds_pos_camera_space[:,2] > self.near
gaussian_3ds_pos_image_space = camera_to_image(gaussian_3ds_pos_camera_space)
culling_mask = (gaussian_3ds_pos_image_space[:, 0].abs() < (self.current_camera.width*1.2/2/self.current_camera.params[0])) & \
(gaussian_3ds_pos_image_space[:, 1].abs() < (self.current_camera.height*1.2/2/self.current_camera.params[1]))
valid &= culling_mask
self.gaussian_3ds_valid = self.gaussian_3ds.filte(valid)
self.culling_gaussian_3d_image_space = self.gaussian_3ds_valid.project(
self.current_w2c_rot,
self.current_w2c_tran,
self.near,
self.jacobian_calc,
scale_activation=self.scale_activation,
)

def render(self, out_write=True):
# self.tic.record()
with Timer(" culling tiles", debug=self.debug):
tile_n_point = torch.zeros(len(self.tile_info), device=self.device, dtype=torch.int32)
# print("***{}***".format(len(self.culling_gaussian_3d_image_space.pos)//10))
tile_gaussian_list = torch.ones(len(self.tile_info), len(self.culling_gaussian_3d_image_space.pos)//10, device=self.device, dtype=torch.int32) * -1
_method_config = {"dist": 0, "prob": 1, "prob2": 2}
gaussian.calc_tile_list(
Expand All @@ -450,6 +487,7 @@ def render(self, out_write=True):
self.tile_info.topmost,
# 0.01,
)
# print(tile_gaussian_list.sum())

with Timer(" gather culled tiles", debug=self.debug):
gathered_list = torch.empty(tile_n_point.sum(), dtype=torch.int32, device=self.device)
Expand Down Expand Up @@ -498,14 +536,14 @@ def render(self, out_write=True):

return rendered_image

def forward(self, camera_id=None, record_view_space_pos_grad=False, cudaculling=False):
def forward(self, camera_id=None, record_view_space_pos_grad=False):
# print(self.gaussian_3ds.opa.max())
# print(self.gaussian_3ds.opa.min())
with Timer("forward", debug=self.debug):
with Timer("set camera"):
self.set_camera(camera_id)
with Timer("frustum culling", debug=self.debug):
self.project_and_culling(cudaculling)
self.project_and_culling()
with Timer("render function", debug=self.debug):
padded_render_img = self.render(out_write=False)
with Timer("crop", debug=self.debug):
Expand Down
5 changes: 2 additions & 3 deletions src/gaussian.cu
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ void gather_gaussians(
max_points_for_tile,
gaussian_list_size
);

}

template<uint32_t SMSIZE>
Expand Down Expand Up @@ -470,7 +469,7 @@ __global__ void draw_backward_kernel(
global_idx = i_loadings*SMSIZE + i;

// check bound and early stop
if(global_idx>=(end_idx-start_idx)||accum < 0.01){
if(global_idx>=(end_idx-start_idx)||accum < 0.0001){
break;
}
_a = _gaussian_cov[i*4+0];
Expand Down Expand Up @@ -666,7 +665,7 @@ __global__ void draw_kernel(
global_idx = i_loadings*SMSIZE + i;

// check bound and early stop
if(global_idx>=(end_idx-start_idx)||accum < 0.01){
if(global_idx>=(end_idx-start_idx)||accum < 0.0001){
break;
}
_a = _gaussian_cov[i*4+0];
Expand Down
Loading

0 comments on commit 1abe284

Please sign in to comment.