Skip to content

Commit

Permalink
Distortion loss (#58)
Browse files Browse the repository at this point in the history
* add distortion loss

* update init

* bump version
  • Loading branch information
liruilong940607 authored Oct 7, 2022
1 parent 601c4c3 commit 3d95832
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 1 deletion.
2 changes: 2 additions & 0 deletions nerfacc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .contraction import ContractionType, contract, contract_inv
from .grid import Grid, OccupancyGrid, query_grid
from .intersection import ray_aabb_intersect
from .losses import distortion as loss_distortion
from .pack import pack_data, unpack_data, unpack_info
from .ray_marching import ray_marching
from .version import __version__
Expand Down Expand Up @@ -48,4 +49,5 @@ def unpack_to_ray_indices(*args, **kwargs):
"unpack_data",
"unpack_info",
"ray_resampling",
"loss_distortion",
]
32 changes: 32 additions & 0 deletions nerfacc/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch import Tensor

from .pack import unpack_data


def distortion(
packed_info: Tensor, weights: Tensor, t_starts: Tensor, t_ends: Tensor
) -> Tensor:
"""Distortion loss from Mip-NeRF 360 paper, Equ. 15.
Args:
packed_info: Packed info for the samples. (n_rays, 2)
weights: Weights for the samples. (all_samples,)
t_starts: Per-sample start distance. Tensor with shape (all_samples, 1).
t_ends: Per-sample end distance. Tensor with shape (all_samples, 1).
Returns:
Distortion loss. (n_rays,)
"""
# (all_samples, 1) -> (n_rays, n_samples)
w = unpack_data(packed_info, weights[..., None]).squeeze(-1)
t1 = unpack_data(packed_info, t_starts).squeeze(-1)
t2 = unpack_data(packed_info, t_ends).squeeze(-1)

interval = t2 - t1
tmid = (t1 + t2) / 2

loss_uni = (1 / 3) * (interval * w.pow(2)).sum(-1)
ww = w.unsqueeze(-1) * w.unsqueeze(-2)
mm = (tmid.unsqueeze(-1) - tmid.unsqueeze(-2)).abs()
loss_bi = (ww * mm).sum((-1, -2))
return loss_uni + loss_bi
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nerfacc"
version = "0.1.7"
version = "0.1.8"
description = "A General NeRF Acceleration Toolbox."
readme = "README.md"
authors = [{name = "Ruilong", email = "[email protected]"}]
Expand Down
31 changes: 31 additions & 0 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import torch

from nerfacc import ray_marching
from nerfacc.losses import distortion

device = "cuda:0"
batch_size = 32
eps = 1e-6


@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_distortion():
rays_o = torch.rand((batch_size, 3), device=device)
rays_d = torch.randn((batch_size, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

packed_info, t_starts, t_ends = ray_marching(
rays_o,
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=1e-3,
)
weights = torch.rand((t_starts.shape[0],), device=device)
loss = distortion(packed_info, weights, t_starts, t_ends)
assert loss.shape == (batch_size,)


if __name__ == "__main__":
test_distortion()

0 comments on commit 3d95832

Please sign in to comment.