
Dataloading Revamp #3216

Open · AntonioMacaronio wants to merge 81 commits into base: main

Commits (81):
0471543
initial debugging and testing works
AntonioMacaronio Jun 11, 2024
c6dde7d
pwais changes with RayBatchStream to alleviate training
AntonioMacaronio Jun 12, 2024
a09ea0c
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Jun 12, 2024
78453cd
few bugs to iron out with multiprocessing, specifically pickled colla…
AntonioMacaronio Jun 12, 2024
f2bd96f
working version of RayBatchStream
AntonioMacaronio Jun 13, 2024
d8b7430
additional docstrings
AntonioMacaronio Jun 13, 2024
a5425d4
cleanup
AntonioMacaronio Jun 13, 2024
604f734
much more documentation
AntonioMacaronio Jun 13, 2024
0143803
successfully trained AEA-script2_seq2 closed_loop without OOM
AntonioMacaronio Jun 13, 2024
d3527e2
porting over aria dataset-size feature
AntonioMacaronio Jun 13, 2024
25f5f27
added logic to handle eviction of a worker's cached_collated_batch
AntonioMacaronio Jun 14, 2024
3a8b63b
antonio's implementation of stream batches
AntonioMacaronio Jun 15, 2024
536c6ca
training on a dataset with 4000 images works!
AntonioMacaronio Jun 15, 2024
43a0061
some configuration speedups, loops aren't actually needed!
AntonioMacaronio Jun 15, 2024
fa7cf30
quick fix adjustment to aria
AntonioMacaronio Jun 15, 2024
927cb6a
removed unnecessary looping
AntonioMacaronio Jun 16, 2024
814f2c2
much faster training when adding i variable to collate every 5 ray bu…
AntonioMacaronio Jun 25, 2024
247ac3e
cleanup unnecssary variables in Dataloader
AntonioMacaronio Jul 7, 2024
55d0803
further cleanup
AntonioMacaronio Jul 11, 2024
b6979a4
adding caching of compressed images to RAM to reduce disk bottleneck
AntonioMacaronio Jul 20, 2024
81dbf7c
added caching to RAM for masks
AntonioMacaronio Jul 22, 2024
55ca71d
found fast way to collate - many tricks applied
AntonioMacaronio Jul 26, 2024
3b4f091
quick update to aria to test on different datasets
AntonioMacaronio Jul 26, 2024
7de1922
cleaned up the accelerated pil_to_numpy function
AntonioMacaronio Jul 26, 2024
9ceaad1
cleaning up PR
AntonioMacaronio Jul 26, 2024
4147a6a
this commit was used to generate the time metrics and profiling metrics
AntonioMacaronio Jul 26, 2024
5a55b7a
REAL commit used to run tests
AntonioMacaronio Jul 26, 2024
78f02e6
testing with nerfacto-big
AntonioMacaronio Aug 15, 2024
19bc4b5
generated RayBundle collate and converting images from uint8s to floa…
AntonioMacaronio Aug 15, 2024
9245d05
updating nerfacto to support uint8 easily, will need to figure out a …
AntonioMacaronio Aug 20, 2024
3124c14
datamanager updates, both splat and nerf
AntonioMacaronio Aug 20, 2024
afb0612
must use writeable arrays because torch requires them
AntonioMacaronio Aug 20, 2024
288a740
cleaned up base_dataset, added pickle to utils, more code in full_ima…
AntonioMacaronio Aug 22, 2024
2fd0862
lots of process on a parallel FullImageDatamanger
AntonioMacaronio Aug 23, 2024
846e2f3
can train big splats with pre-assertion hack or ROI hack and 0 workers
AntonioMacaronio Aug 24, 2024
8fb0b4d
fixed all undistortion issues with ParallelImageDatamanager
AntonioMacaronio Aug 27, 2024
ce3f83f
adding some downsampling and parallel tests with splatfacto!
AntonioMacaronio Aug 31, 2024
8ab9963
deleted commented code in dataloaders.py and added bugfix to shuffling
AntonioMacaronio Aug 31, 2024
c9e16bf
testing splatfacto-big
AntonioMacaronio Sep 1, 2024
ddac38d
cleaned up base_pipeline.py
AntonioMacaronio Sep 1, 2024
443719a
cleaned up base_pipeline.py ACTUALLY THIS TIME, forgot to save last time
AntonioMacaronio Sep 1, 2024
d16e519
cleaned up a lot of code
AntonioMacaronio Sep 1, 2024
367d512
process_project_aria back to main branch and some cleanup in full_ima…
AntonioMacaronio Sep 1, 2024
d3d99b4
clarifying docstrings
AntonioMacaronio Sep 1, 2024
6f763dc
further PR cleanup
AntonioMacaronio Sep 3, 2024
a5191bd
updating models
AntonioMacaronio Sep 9, 2024
7db70dc
further cleanup
AntonioMacaronio Sep 9, 2024
5c3262b
removed caching of images into bytestrings
AntonioMacaronio Sep 9, 2024
ff2bda1
adding caching of compressed images to RAM, forgot that hardware matters
AntonioMacaronio Sep 9, 2024
f6dd7dd
removing oom methods, adding the ability to add a flag to dataloading
AntonioMacaronio Sep 15, 2024
a6602c7
removed CacheDataloader, moved RayBatchStream to dataloaders.py, new …
AntonioMacaronio Sep 15, 2024
3dc2031
fixing base_piplines, deleting a weird datamanager_configs file that …
AntonioMacaronio Sep 15, 2024
89f3d98
cleaning up next_train
AntonioMacaronio Sep 15, 2024
14e60e5
replaced parallel datamanager with new datamanager
AntonioMacaronio Sep 19, 2024
204dfb2
reverted the original base_datamanager.py, new datamanager replaced p…
AntonioMacaronio Sep 19, 2024
5864bc9
modified VanillaConfig, but VanillaDataManager is the same as before
AntonioMacaronio Sep 19, 2024
6d97de3
cleaning up, 2 datamanagers now - original and new parallel one
AntonioMacaronio Sep 19, 2024
1f34017
able to train with new nerfstudio dataloader now
AntonioMacaronio Sep 19, 2024
99cf86a
side by side datamanagers, moved tons of logic into dataloaders.py an…
AntonioMacaronio Sep 23, 2024
4ebad85
added custom ray processing API to support implementations like LERF,…
AntonioMacaronio Sep 23, 2024
87921be
adding functionality for ns-eval by adding FixedIndicesEvalDataloader…
AntonioMacaronio Sep 24, 2024
b628c7c
adding both ray API and image-view API to datamanagers for custom par…
AntonioMacaronio Sep 27, 2024
d2785d1
updating splatfacto config for 4k tests
AntonioMacaronio Sep 30, 2024
436af9d
updating docstrings to be more descriptive
AntonioMacaronio Sep 30, 2024
dd4daaa
new datamanager API breaks when setup_eval() has multiple workers, no…
AntonioMacaronio Sep 30, 2024
43c66ae
adding custom_view_processor to ImageBatchStream
AntonioMacaronio Sep 30, 2024
ba81e11
merging with main!
AntonioMacaronio Sep 30, 2024
1922566
reverting full_images_datamanager to main branch
AntonioMacaronio Oct 1, 2024
beb74be
removing nn.Module inheritance from Datamanager class
AntonioMacaronio Oct 1, 2024
087cff0
don't need to move datamanger to device anymore since Datamanager is …
AntonioMacaronio Oct 1, 2024
48e6d15
finished integration test with nerfacto
AntonioMacaronio Oct 4, 2024
3f1799b
simplified config variables, integrated the parallelism/disk-data-loa…
AntonioMacaronio Oct 25, 2024
f46aa42
updated the splatfacto config to be simpler with the dataloading and …
AntonioMacaronio Oct 25, 2024
5aa51fb
style checks and some cleanup
AntonioMacaronio Oct 25, 2024
ec3c12a
new splatfacto test, cleaning up nerfacto integration test
AntonioMacaronio Oct 25, 2024
82bc5b2
removing redundant parallel_full_images_datamaanger, as the OG full_i…
AntonioMacaronio Oct 26, 2024
377a56a
Merge branch 'main' into dataloading-revamp
AntonioMacaronio Oct 28, 2024
bbb5473
ruff linting and pyright fixing
AntonioMacaronio Oct 28, 2024
2e64120
further pyright fixing
AntonioMacaronio Oct 28, 2024
e9c2fd6
another pyright fixing
AntonioMacaronio Oct 28, 2024
e4dc9f9
fixing pyright error, camera optimization no longer part of datamanager
AntonioMacaronio Nov 1, 2024
27 changes: 19 additions & 8 deletions nerfstudio/configs/method_configs.py
@@ -28,7 +28,7 @@
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
@@ -37,6 +37,7 @@
from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig
from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.datasets.depth_dataset import DepthDataset
from nerfstudio.data.datasets.sdf_dataset import SDFDataset
from nerfstudio.data.datasets.semantic_dataset import SemanticDataset
@@ -91,10 +92,13 @@
max_num_iterations=30000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
_target=ParallelDataManager[InputDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
load_from_disk=True,
use_parallel_dataloader=True,
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
@@ -127,10 +131,13 @@
max_num_iterations=100000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
_target=ParallelDataManager[InputDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=8192,
eval_num_rays_per_batch=4096,
load_from_disk=True,
use_parallel_dataloader=True,
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
@@ -171,7 +178,7 @@
max_num_iterations=100000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=16384,
eval_num_rays_per_batch=4096,
@@ -302,7 +309,7 @@
method_configs["mipnerf"] = TrainerConfig(
method_name="mipnerf",
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
datamanager=VanillaDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024),
model=VanillaModelConfig(
_target=MipNerfModel,
loss_coefficients={"rgb_loss_coarse": 0.1, "rgb_loss_fine": 1.0},
@@ -375,7 +382,7 @@
max_num_iterations=30000,
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=ParallelDataManagerConfig(
datamanager=VanillaDataManagerConfig(
dataparser=BlenderDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
@@ -599,8 +606,10 @@
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=FullImageDatamanagerConfig(
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1),
# dataparser=NerfstudioDataParserConfig(load_3D_points=True),
cache_images_type="uint8",
cache_images="disk",
),
model=SplatfactoModelConfig(),
),
Expand Down Expand Up @@ -656,8 +665,10 @@
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=FullImageDatamanagerConfig(
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1),
# dataparser=NerfstudioDataParserConfig(load_3D_points=True),
cache_images_type="uint8",
cache_images="disk",
),
model=SplatfactoModelConfig(
cull_alpha_thresh=0.005,
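The hunks above follow one pattern: the dedicated `ParallelDataManagerConfig` is retired, and parallel ray dataloading is now selected on the generic `VanillaDataManagerConfig` via an explicit `_target=ParallelDataManager[InputDataset]` plus the new `load_from_disk` / `use_parallel_dataloader` flags. A hedged sketch of what a datamanager block using the revamped flags might look like — all identifiers are read off the diff above, and the surrounding `TrainerConfig` fields are omitted, so treat this as illustrative rather than a verbatim method definition:

```python
# Sketch only: datamanager portion of a method config under the revamp.
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset

datamanager = VanillaDataManagerConfig(
    _target=ParallelDataManager[InputDataset],  # generic target, as in the diff
    dataparser=NerfstudioDataParserConfig(),
    train_num_rays_per_batch=4096,
    eval_num_rays_per_batch=4096,
    load_from_disk=True,            # stream images from disk instead of caching all to RAM
    use_parallel_dataloader=True,   # worker processes prefetch RayBundles
)
```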
44 changes: 39 additions & 5 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -42,7 +42,6 @@

import torch
import tyro
from torch import nn
from torch.nn import Parameter
from torch.utils.data.distributed import DistributedSampler
from typing_extensions import TypeVar
@@ -86,7 +85,6 @@ def variable_res_collate(batch: List[Dict]) -> Dict:
# now that iteration is complete, the image data items can be removed from the batch
for key in topop:
del data[key]

new_batch = nerfstudio_collate(batch)
new_batch["image"] = images
new_batch.update(imgdata_lists)
@@ -111,7 +109,7 @@ class DataManagerConfig(InstantiateConfig):
"""Process images on GPU for speed at the expense of memory, if True."""


class DataManager(nn.Module):
class DataManager:
"""Generic data manager's abstract class

This version of the data manager is designed be a monolithic way to load data and latents,
@@ -311,6 +309,8 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""Target class to instantiate."""
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)
"""Specifies the dataparser used to unpack the data."""
cache_images_type: Literal["uint8", "float32"] = "float32"
"""The image type returned from manager, caching images in uint8 saves memory"""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
train_num_images_to_sample_from: int = -1
@@ -331,10 +331,22 @@
"""Specifies the collate function to use for the train and eval dataloaders."""
camera_res_scale_factor: float = 1.0
"""The scale factor for scaling spatial data such as images, mask, semantics
along with relevant information about camera intrinsics
"""
along with relevant information about camera intrinsics"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
use_parallel_dataloader: bool = False
"""Allows parallelization of the dataloading process with multiple workers prefetching RayBundles."""
load_from_disk: bool = False
"""If True, conserves RAM memory by loading images from disk.
If False, caches all the images as tensors to RAM and loads from RAM."""
dataloader_num_workers: int = 0
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
prefetch_factor: int | None = None
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
"""If True, cache raw image files as byte strings to RAM."""

# tyro.conf.Suppress prevents us from creating CLI arguments for this field.
camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None)
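The new `cache_images_type` docstring notes that caching images as uint8 saves memory; the saving is exactly 4x, since uint8 stores one byte per channel against float32's four. Quick back-of-envelope arithmetic (the dataset size and resolution below are hypothetical, not taken from the PR):

```python
# Back-of-envelope RAM cost of caching N RGB images at HxW resolution.
# uint8 = 1 byte/channel, float32 = 4 bytes/channel, so float32 is 4x larger.
def cache_bytes(n_images: int, h: int, w: int, dtype_bytes: int) -> int:
    return n_images * h * w * 3 * dtype_bytes

n, h, w = 300, 1080, 1920  # hypothetical full-HD dataset
uint8_mib = cache_bytes(n, h, w, 1) / 2**20
float32_mib = cache_bytes(n, h, w, 4) / 2**20
print(round(uint8_mib), round(float32_mib))  # 1780 7119
```

For 300 full-HD images that is roughly 1.7 GiB as uint8 versus roughly 7 GiB as float32, which is why the splatfacto configs in this PR switch to `cache_images_type="uint8"`.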
@@ -352,6 +364,26 @@ def __post_init__(self):
)
warnings.warn("above message coming from", FutureWarning, stacklevel=3)

"""
These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted
Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck.
"""
if self.load_from_disk:
self.train_num_images_to_sample_from = (
50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from
)
self.train_num_times_to_repeat_images = (
10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images
)
self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None

if self.use_parallel_dataloader:
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers


TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)
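The `__post_init__` additions above do two things: resolve the `-1` sentinels into concrete defaults (50 images sampled, each batch repeated 10 times) when `load_from_disk` is set, and guard the one-time `spawn` start-method call so a second invocation does not raise. A standalone sketch of that logic, detached from nerfstudio (plain dataclass, stdlib `multiprocessing` standing in for `torch.multiprocessing`; the 50/10/4 defaults are read off the diff):

```python
import multiprocessing
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoaderHeuristics:
    """Mirrors the sentinel-resolution heuristics in VanillaDataManagerConfig.__post_init__."""
    load_from_disk: bool = False
    use_parallel_dataloader: bool = False
    train_num_images_to_sample_from: int = -1   # -1 means "unset"
    train_num_times_to_repeat_images: int = -1  # -1 means "unset"
    prefetch_factor: Optional[int] = None
    dataloader_num_workers: int = 0

    def resolve(self) -> None:
        if self.load_from_disk:
            # Sample 50 images and reuse each cached batch 10 times,
            # unless the user already overrode the sentinels.
            if self.train_num_images_to_sample_from == -1:
                self.train_num_images_to_sample_from = 50
            if self.train_num_times_to_repeat_images == -1:
                self.train_num_times_to_repeat_images = 10
        self.prefetch_factor = (
            self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None
        )
        if self.use_parallel_dataloader:
            try:
                # Worker processes must be spawned, not forked; the guard
                # swallows the error if a start method was already chosen.
                multiprocessing.set_start_method("spawn")
            except RuntimeError:
                pass
            if self.dataloader_num_workers == 0:
                self.dataloader_num_workers = 4

cfg = LoaderHeuristics(load_from_disk=True, use_parallel_dataloader=True)
cfg.resolve()
print(cfg.train_num_images_to_sample_from, cfg.prefetch_factor, cfg.dataloader_num_workers)
# 50 10 4
```

The same try/except pattern appears verbatim in the diff; it is what lets the config be constructed more than once in a single process.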

@@ -451,13 +483,15 @@ def create_train_dataset(self) -> TDataset:
return self.dataset_type(
dataparser_outputs=self.train_dataparser_outputs,
scale_factor=self.config.camera_res_scale_factor,
cache_compressed_images=self.config.cache_compressed_images,
)

def create_eval_dataset(self) -> TDataset:
"""Sets up the data loaders for evaluation"""
return self.dataset_type(
dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
scale_factor=self.config.camera_res_scale_factor,
cache_compressed_images=self.config.cache_compressed_images,
)

def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: