Skip to content

Commit

Permalink
example with vanilla nerf
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 committed Sep 17, 2022
1 parent f66bacf commit 08761ab
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ python examples/trainval.py vanilla --train_split train

Performance on test set:

| | Lego |
| - | - |
| Paper PSNR (train set) | 32.54 |
| Our PSNR (train set) | 33.21 |
| Our PSNR (trainval set) | 33.66 |
| Our train time & test FPS | 45min; 0.43FPS |
| | Lego | Mic | Materials | Chair | Hotdog |
| - | - | - | - | - | - |
| Paper PSNR (train set) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 |
| Our PSNR (train set) | 33.21 | 33.36 | 29.48 | 32.79 | 35.54 |
| Our PSNR (trainval set) | 33.66 | - | - | - | - | - |
| Our train time & test FPS | 45min; 0.43FPS | 44min; 1FPS | 37min; 0.33FPS* | 44min; 0.57FPS* | 50min; 0.15 FPS* |

For reference, vanilla NeRF paper trains on V100 GPU for 1-2 days per scene. Test time rendering takes about 30 secs to render a 800x800 image. Our model is trained on a TITAN X.

Note: We only use a single MLP with more samples (1024), instead of two MLPs with coarse-to-fine sampling as in the paper. Both ways share the same spirit to do dense sampling around the surface. Our fast rendering inheritly skip samples away from the surface so we can simplly increase the number of samples with a single MLP, to achieve the same goal with coarse-to-fine sampling, without runtime or memory issue.

*FPS for some scenes are tested under `--test_chunk_size=8192` (default is `81920`) to avoid OOM.

<!--
Tested with the default settings on the Lego test set.
Expand Down
1 change: 1 addition & 0 deletions examples/datasets/nerf_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class SubjectLoader(torch.utils.data.Dataset):
"lego",
"materials",
"mic",
"ship",
]

WIDTH, HEIGHT = 800, 800
Expand Down
35 changes: 31 additions & 4 deletions examples/trainval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
TARGET_SAMPLE_BATCH_SIZE = 1 << 16


def render_image(radiance_field, rays, render_bkgd, render_step_size):
def render_image(
radiance_field, rays, render_bkgd, render_step_size, test_chunk_size=81920
):
"""Render the pixels of an image.
Args:
Expand Down Expand Up @@ -48,7 +50,7 @@ def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
return radiance_field(positions, frustum_dirs)

results = []
chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
chunk_results = volumetric_rendering(
Expand Down Expand Up @@ -95,10 +97,31 @@ def sigma_rgb_fn(frustum_origins, frustum_dirs, frustum_starts, frustum_ends):
choices=["train", "trainval"],
help="which train split to use",
)
parser.add_argument(
"--scene",
type=str,
default="lego",
choices=[
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
"ship",
],
help="which scene to use",
)
parser.add_argument(
"--test_chunk_size",
type=int,
default=81920,
)
args = parser.parse_args()

device = "cuda:0"
scene = "lego"
scene = args.scene

# setup the scene bounding box.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
Expand Down Expand Up @@ -251,7 +274,11 @@ def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, acc, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size
radiance_field,
rays,
render_bkgd,
render_step_size,
test_chunk_size=args.test_chunk_size,
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
Expand Down
1 change: 0 additions & 1 deletion nerfacc/volumetric_rendering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Callable, List, Optional, Tuple


import torch

from .utils import (
Expand Down

0 comments on commit 08761ab

Please sign in to comment.