diff --git a/nerfacc/volrend.py b/nerfacc/volrend.py index 07dbff88..62bc42c5 100644 --- a/nerfacc/volrend.py +++ b/nerfacc/volrend.py @@ -155,6 +155,7 @@ def render_transmittance_from_alpha( packed_info: Optional[Tensor] = None, ray_indices: Optional[Tensor] = None, n_rays: Optional[int] = None, + prefix_trans: Optional[Tensor] = None, ) -> Tensor: """Compute transmittance :math:`T_i` from alpha :math:`\\alpha_i`. @@ -171,6 +172,7 @@ def render_transmittance_from_alpha( Useful for flattened input. ray_indices: Ray indices of the flattened samples. LongTensor with shape (all_samples). n_rays: Number of rays. Only useful when `ray_indices` is provided. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: The rendering transmittance with the same shape as `alphas`. @@ -191,6 +193,8 @@ def render_transmittance_from_alpha( packed_info = pack_info(ray_indices, n_rays) trans = exclusive_prod(1 - alphas, packed_info) + if prefix_trans is not None: + trans *= prefix_trans return trans @@ -201,6 +205,7 @@ def render_transmittance_from_density( packed_info: Optional[Tensor] = None, ray_indices: Optional[Tensor] = None, n_rays: Optional[int] = None, + prefix_trans: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute transmittance :math:`T_i` from density :math:`\\sigma_i`. @@ -221,6 +226,7 @@ def render_transmittance_from_density( Useful for flattened input. ray_indices: Ray indices of the flattened samples. LongTensor with shape (all_samples). n_rays: Number of rays. Only useful when `ray_indices` is provided. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: The rendering transmittance and opacities, both with the same shape as `sigmas`. @@ -245,6 +251,8 @@ def render_transmittance_from_density( sigmas_dt = sigmas * (t_ends - t_starts) alphas = 1.0 - torch.exp(-sigmas_dt) trans = torch.exp(-exclusive_sum(sigmas_dt, packed_info)) + if prefix_trans is not None: + trans *= prefix_trans return trans, alphas @@ -253,6 +261,7 @@ def render_weight_from_alpha( packed_info: Optional[Tensor] = None, ray_indices: Optional[Tensor] = None, n_rays: Optional[int] = None, + prefix_trans: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Compute rendering weights :math:`w_i` from opacity :math:`\\alpha_i`. @@ -269,6 +278,7 @@ def render_weight_from_alpha( Useful for flattened input. ray_indices: Ray indices of the flattened samples. LongTensor with shape (all_samples). n_rays: Number of rays. Only useful when `ray_indices` is provided. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: The rendering weights and transmittance, both with the same shape as `alphas`. @@ -285,7 +295,7 @@ def render_weight_from_alpha( """ trans = render_transmittance_from_alpha( - alphas, packed_info, ray_indices, n_rays + alphas, packed_info, ray_indices, n_rays, prefix_trans ) weights = trans * alphas return weights, trans @@ -298,6 +308,7 @@ def render_weight_from_density( packed_info: Optional[Tensor] = None, ray_indices: Optional[Tensor] = None, n_rays: Optional[int] = None, + prefix_trans: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Compute rendering weights :math:`w_i` from density :math:`\\sigma_i` and interval :math:`\\delta_i`. @@ -316,6 +327,7 @@ def render_weight_from_density( Useful for flattened input. ray_indices: Ray indices of the flattened samples. LongTensor with shape (all_samples). n_rays: Number of rays. Only useful when `ray_indices` is provided. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: The rendering weights, transmittance and opacities, both with the same shape as `sigmas`. @@ -336,7 +348,7 @@ def render_weight_from_density( """ trans, alphas = render_transmittance_from_density( - t_starts, t_ends, sigmas, packed_info, ray_indices, n_rays + t_starts, t_ends, sigmas, packed_info, ray_indices, n_rays, prefix_trans ) weights = trans * alphas return weights, trans, alphas @@ -350,6 +362,7 @@ def render_visibility_from_alpha( n_rays: Optional[int] = None, early_stop_eps: float = 1e-4, alpha_thre: float = 0.0, + prefix_trans: Optional[Tensor] = None, ) -> Tensor: """Compute visibility from opacity :math:`\\alpha_i`. @@ -370,6 +383,7 @@ def render_visibility_from_alpha( n_rays: Number of rays. Only useful when `ray_indices` is provided. early_stop_eps: The early stopping threshold on transmittance. alpha_thre: The threshold on opacity. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: A boolean tensor indicating which samples are visible. Same shape as `alphas`. @@ -388,7 +402,7 @@ def render_visibility_from_alpha( """ trans = render_transmittance_from_alpha( - alphas, packed_info, ray_indices, n_rays + alphas, packed_info, ray_indices, n_rays, prefix_trans ) vis = trans >= early_stop_eps if alpha_thre > 0: @@ -406,6 +420,7 @@ def render_visibility_from_density( n_rays: Optional[int] = None, early_stop_eps: float = 1e-4, alpha_thre: float = 0.0, + prefix_trans: Optional[Tensor] = None, ) -> Tensor: """Compute visibility from density :math:`\\sigma_i` and interval :math:`\\delta_i`. @@ -426,6 +441,7 @@ def render_visibility_from_density( n_rays: Number of rays. Only useful when `ray_indices` is provided. early_stop_eps: The early stopping threshold on transmittance. alpha_thre: The threshold on opacity. + prefix_trans: The pre-computed transmittance of the samples. Tensor with shape (all_samples,). Returns: A boolean tensor indicating which samples are visible. Same shape as `alphas`. @@ -448,7 +464,7 @@ def render_visibility_from_density( """ trans, alphas = render_transmittance_from_density( - t_starts, t_ends, sigmas, packed_info, ray_indices, n_rays + t_starts, t_ends, sigmas, packed_info, ray_indices, n_rays, prefix_trans ) vis = trans >= early_stop_eps if alpha_thre > 0: