Skip to content

Commit

Permalink
Support setting alpha threshold for marching and rendering (#42)
Browse files Browse the repository at this point in the history
* support alpha_thre for rendering and ray marching. default to zero

* bump version
  • Loading branch information
liruilong940607 authored Oct 4, 2022
1 parent daf3559 commit 542f431
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
copyright = "2022, Ruilong"
author = "Ruilong"

release = "0.1.2"
version = "0.1.2"
release = "0.1.4"
version = "0.1.4"

# -- General configuration

Expand Down
8 changes: 6 additions & 2 deletions nerfacc/cuda/csrc/pybind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression);

torch::Tensor rendering_backward(
Expand All @@ -17,7 +18,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);

std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
Expand Down Expand Up @@ -65,12 +67,14 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps);
float early_stop_eps,
float alpha_thre);

std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
Expand Down
31 changes: 29 additions & 2 deletions nerfacc/cuda/csrc/rendering.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __global__ void rendering_forward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for emtpy space
// outputs: should be all-zero initialized
int *num_steps, // the number of valid steps for each ray
scalar_t *weights, // the number rendering weights for each sample
Expand Down Expand Up @@ -70,6 +71,11 @@ __global__ void rendering_forward_kernel(
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
}
if (alpha < alpha_thre)
{
// empty space
continue;
}
const scalar_t weight = alpha * T;
T *= (1.f - alpha);
if (weights != nullptr)
Expand Down Expand Up @@ -97,6 +103,7 @@ __global__ void rendering_backward_kernel(
const scalar_t *sigmas, // input density after activation
const scalar_t *alphas, // input alpha (opacity) values.
const scalar_t early_stop_eps, // transmittance threshold for early stop
const scalar_t alpha_thre, // alpha threshold for emtpy space
const scalar_t *weights, // forward output
const scalar_t *grad_weights, // input gradients
// if alphas was given, we compute the gradients for alphas.
Expand Down Expand Up @@ -150,13 +157,23 @@ __global__ void rendering_backward_kernel(
{
// rendering with alpha
alpha = alphas[j];
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_alphas[j] = (grad_weights[j] * T - accum) / fmaxf(1.f - alpha, 1e-10f);
}
else
{
// rendering with density
scalar_t delta = ends[j] - starts[j];
alpha = 1.f - __expf(-sigmas[j] * delta);
if (alpha < alpha_thre)
{
// empty space
continue;
}
grad_sigmas[j] = (grad_weights[j] * T - accum) * delta;
}

Expand All @@ -171,6 +188,7 @@ std::vector<torch::Tensor> rendering_forward(
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
Expand Down Expand Up @@ -211,6 +229,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
Expand Down Expand Up @@ -238,6 +257,7 @@ std::vector<torch::Tensor> rendering_forward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
Expand All @@ -254,7 +274,8 @@ torch::Tensor rendering_backward(
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
Expand All @@ -279,6 +300,7 @@ torch::Tensor rendering_backward(
sigmas.data_ptr<scalar_t>(),
nullptr, // alphas
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
Expand All @@ -295,6 +317,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps,
float alpha_thre,
bool compression)
{
DEVICE_GUARD(packed_info);
Expand Down Expand Up @@ -331,6 +354,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
num_steps.data_ptr<int>(),
nullptr,
Expand Down Expand Up @@ -358,6 +382,7 @@ std::vector<torch::Tensor> rendering_alphas_forward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
// outputs
nullptr,
weights.data_ptr<scalar_t>(),
Expand All @@ -372,7 +397,8 @@ torch::Tensor rendering_alphas_backward(
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor alphas,
float early_stop_eps)
float early_stop_eps,
float alpha_thre)
{
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
Expand All @@ -397,6 +423,7 @@ torch::Tensor rendering_alphas_backward(
nullptr, // sigmas
alphas.data_ptr<scalar_t>(),
early_stop_eps,
alpha_thre,
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
// outputs
Expand Down
4 changes: 3 additions & 1 deletion nerfacc/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def rendering(
t_ends: torch.Tensor,
# rendering options
early_stop_eps: float = 1e-4,
alpha_thre: float = 1e-2,
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Render the rays through the radience field defined by `rgb_sigma_fn`.
Expand All @@ -33,6 +34,7 @@ def rendering(
t_starts: Per-sample start distance. Tensor with shape (n_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (n_samples, 1).
early_stop_eps: Early stop threshold during trasmittance accumulation. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
render_bkgd: Optional. Background color. Tensor with shape (3,).
Returns:
Expand Down Expand Up @@ -82,7 +84,7 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices):

# Rendering: compute weights and ray indices.
weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
)

# Rendering: accumulate rgbs, opacities, and depths along the rays.
Expand Down
4 changes: 3 additions & 1 deletion nerfacc/ray_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def ray_marching(
# sigma function for skipping invisible space
sigma_fn: Optional[Callable] = None,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
# rendering options
near_plane: Optional[float] = None,
far_plane: Optional[float] = None,
Expand Down Expand Up @@ -140,6 +141,7 @@ def ray_marching(
function that takes in samples {t_starts (N, 1), t_ends (N, 1),
ray indices (N,)} and returns the post-activation density values (N, 1).
early_stop_eps: Early stop threshold for skipping invisible space. Default: 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
near_plane: Optional. Near plane distance. If provided, it will be used
to clip t_min.
far_plane: Optional. Far plane distance. If provided, it will be used
Expand Down Expand Up @@ -272,7 +274,7 @@ def ray_marching(

# Compute visibility of the samples, and filter out invisible samples
visibility, packed_info_visible = render_visibility(
packed_info, alphas, early_stop_eps
packed_info, alphas, early_stop_eps, alpha_thre
)
t_starts, t_ends = t_starts[visibility], t_ends[visibility]
packed_info = packed_info_visible
Expand Down
29 changes: 24 additions & 5 deletions nerfacc/vol_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def render_weight_from_density(
t_ends,
sigmas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> torch.Tensor:
"""Compute transmittance weights from density.
Expand All @@ -94,6 +95,7 @@ def render_weight_from_density(
shape (n_samples, 1).
sigmas: The density values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
transmittance weights with shape (n_samples,).
Expand Down Expand Up @@ -123,7 +125,7 @@ def render_weight_from_density(
if not sigmas.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
weights = _RenderingDensity.apply(
packed_info, t_starts, t_ends, sigmas, early_stop_eps
packed_info, t_starts, t_ends, sigmas, early_stop_eps, alpha_thre
)
return weights

Expand All @@ -132,6 +134,7 @@ def render_weight_from_alpha(
packed_info,
alphas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, ...]:
"""Compute transmittance weights from density.
Expand All @@ -140,7 +143,8 @@ def render_weight_from_alpha(
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
alphas: The opacity values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
transmittance weights with shape (n_samples,).
Expand Down Expand Up @@ -168,7 +172,9 @@ def render_weight_from_alpha(
"""
if not alphas.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
weights = _RenderingAlpha.apply(packed_info, alphas, early_stop_eps)
weights = _RenderingAlpha.apply(
packed_info, alphas, early_stop_eps, alpha_thre
)
return weights


Expand All @@ -177,6 +183,7 @@ def render_visibility(
packed_info: torch.Tensor,
alphas: torch.Tensor,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Filter out invisible samples given alpha (opacity).
Expand All @@ -185,6 +192,7 @@ def render_visibility(
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
alphas: The opacity values of the samples. Tensor with shape (n_samples, 1).
early_stop_eps: The epsilon value for early stopping. Default is 1e-4.
alpha_thre: Alpha threshold for skipping empty space. Default: 0.0.
Returns:
A tuple of tensors.
Expand Down Expand Up @@ -223,6 +231,7 @@ def render_visibility(
packed_info.contiguous(),
alphas.contiguous(),
early_stop_eps,
alpha_thre,
True, # compute visibility instead of weights
)
return visibility, packed_info_visible
Expand All @@ -239,6 +248,7 @@ def forward(
t_ends,
sigmas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
):
packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous()
Expand All @@ -250,6 +260,7 @@ def forward(
t_ends,
sigmas,
early_stop_eps,
alpha_thre,
False, # not doing filtering
)[0]
if ctx.needs_input_grad[3]: # sigmas
Expand All @@ -261,12 +272,14 @@ def forward(
weights,
)
ctx.early_stop_eps = early_stop_eps
ctx.alpha_thre = alpha_thre
return weights

@staticmethod
def backward(ctx, grad_weights):
grad_weights = grad_weights.contiguous()
early_stop_eps = ctx.early_stop_eps
alpha_thre = ctx.alpha_thre
(
packed_info,
t_starts,
Expand All @@ -282,8 +295,9 @@ def backward(ctx, grad_weights):
t_ends,
sigmas,
early_stop_eps,
alpha_thre,
)
return None, None, None, grad_sigmas, None
return None, None, None, grad_sigmas, None, None


class _RenderingAlpha(torch.autograd.Function):
Expand All @@ -295,13 +309,15 @@ def forward(
packed_info,
alphas,
early_stop_eps: float = 1e-4,
alpha_thre: float = 0.0,
):
packed_info = packed_info.contiguous()
alphas = alphas.contiguous()
weights = _C.rendering_alphas_forward(
packed_info,
alphas,
early_stop_eps,
alpha_thre,
False, # not doing filtering
)[0]
if ctx.needs_input_grad[1]: # alphas
Expand All @@ -311,12 +327,14 @@ def forward(
weights,
)
ctx.early_stop_eps = early_stop_eps
ctx.alpha_thre = alpha_thre
return weights

@staticmethod
def backward(ctx, grad_weights):
grad_weights = grad_weights.contiguous()
early_stop_eps = ctx.early_stop_eps
alpha_thre = ctx.alpha_thre
(
packed_info,
alphas,
Expand All @@ -328,5 +346,6 @@ def backward(ctx, grad_weights):
packed_info,
alphas,
early_stop_eps,
alpha_thre,
)
return None, grad_sigmas, None
return None, grad_sigmas, None, None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nerfacc"
version = "0.1.3"
version = "0.1.4"
authors = [{name = "Ruilong", email = "[email protected]"}]
license = { text="MIT" }
requires-python = ">=3.8"
Expand Down

0 comments on commit 542f431

Please sign in to comment.