From 0471543c2064dcebadc7654e644a55da63b47788 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 11 Jun 2024 15:03:05 -0700 Subject: [PATCH 01/78] initial debugging and testing works --- nerfstudio/configs/method_configs.py | 2 +- .../data/datamanagers/base_datamanager.py | 1 + nerfstudio/data/pixel_samplers.py | 4 +- .../scripts/datasets/process_project_aria.py | 297 ++++++++++++++---- 4 files changed, 240 insertions(+), 64 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index d77959fb9f..c15233feae 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -90,7 +90,7 @@ max_num_iterations=30000, mixed_precision=True, pipeline=VanillaPipelineConfig( - datamanager=ParallelDataManagerConfig( + datamanager=VanillaDataManagerConfig( dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=4096, diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index cff03607b9..36e0574109 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -529,6 +529,7 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" + breakpoint() self.train_count += 1 image_batch = next(self.iter_train_image_dataloader) assert self.train_pixel_sampler is not None diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index ad11ee4094..c355ae7830 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -304,7 +304,9 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2" num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images) - + # print(batch.keys()) + # import time + # time.sleep(3) if "mask" in batch: for i, num_rays in enumerate(num_rays_per_image): image_height, image_width, _ = batch["image"][i].shape diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index fe48748325..4ebda7e4e8 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -16,8 +16,9 @@ import sys import threading from dataclasses import dataclass +from itertools import zip_longest from pathlib import Path -from typing import Any, Dict, List, cast +from typing import Any, Dict, List, Literal, Optional, Tuple, cast import numpy as np import open3d as o3d @@ -25,8 +26,9 @@ from PIL import Image try: - from projectaria_tools.core import mps + from projectaria_tools.core import calibration, mps from projectaria_tools.core.data_provider import VrsDataProvider, create_vrs_data_provider + from projectaria_tools.core.image import InterpolationMethod from projectaria_tools.core.mps.utils import filter_points_from_confidence from projectaria_tools.core.sophus import SE3 except ImportError: @@ -68,6 +70,7 @@ class AriaImageFrame: file_path: str t_world_camera: SE3 timestamp_ns: float + pinhole_intrinsic: Tuple[float, float, float, float] @dataclass @@ -76,11 +79,12 @@ class TimedPoses: t_world_devices: List[SE3] -def get_camera_calibs(provider: VrsDataProvider) -> Dict[str, AriaCameraCalibration]: +def get_camera_calibs( + provider: VrsDataProvider, name: Literal["camera-rgb", "camera-slam-left", "camera-slam-right"] = "camera-rgb" +) -> AriaCameraCalibration: """Retrieve the per-camera factory calibration from within the VRS.""" - + assert name in ["camera-rgb", "camera-slam-left", "camera-slam-right"], f"{name} is not a valid camera sensor" factory_calib = {} - name = "camera-rgb" device_calib = provider.get_device_calibration() assert device_calib is not None, "Could not find device calibration" sensor_calib = device_calib.get_camera_calib(name) @@ -101,7 +105,7 @@ def get_camera_calibs(provider: VrsDataProvider) -> Dict[str, AriaCameraCalibrat t_device_camera=sensor_calib.get_transform_device_camera(), ) - return factory_calib + return factory_calib[name] def read_trajectory_csv_to_dict(file_iterable_csv: str) -> TimedPoses: @@ -118,25 +122,101 @@ def read_trajectory_csv_to_dict(file_iterable_csv: str) -> TimedPoses: ) +def undistort_image_and_calibration( + input_image: np.ndarray, + input_calib: calibration.CameraCalibration, + output_focal_length: int, +) -> Tuple[np.ndarray, calibration.CameraCalibration]: + """ + Return the undistorted image and the updated camera calibration. + """ + input_calib_width = input_calib.get_image_size()[0] + input_calib_height = input_calib.get_image_size()[1] + if input_image.shape[1] != input_calib_width or input_image.shape[0] != input_calib_height: + raise ValueError( + f"Input image shape {input_image.shape} does not match calibration {input_calib.get_image_size()}" + ) + + # Undistort the image + pinhole_calib = calibration.get_linear_camera_calibration( + int(input_calib_width), + int(input_calib_height), + output_focal_length, + "pinhole", + input_calib.get_transform_device_camera(), + ) + output_image = calibration.distort_by_calibration( + input_image, pinhole_calib, input_calib, InterpolationMethod.BILINEAR + ) + + return output_image, pinhole_calib + + +def rotate_upright_image_and_calibration( + input_image: np.ndarray, + input_calib: calibration.CameraCalibration, +) -> Tuple[np.ndarray, calibration.CameraCalibration]: + """ + Return the rotated upright image and update both intrinsics and extrinsics of the camera calibration + NOTE: This function only supports pinhole and fisheye624 camera model. + """ + output_image = np.rot90(input_image, k=3) + updated_calib = calibration.rotate_camera_calib_cw90deg(input_calib) + + return output_image, updated_calib + + +def generate_circular_mask(numRows: int, numCols: int, radius: float) -> np.ndarray: + """ + Generates a mask where a circle in the center of the image with input radius is white (sampled from). + Everything outside the circle is black (masked out) + """ + # Calculate the center coordinates + rows, cols = np.ogrid[:numRows, :numCols] + center_row, center_col = numRows // 2, numCols // 2 + + # Calculate the distance of each pixel from the center + distance_from_center = np.sqrt((rows - center_row) ** 2 + (cols - center_col) ** 2) + mask = np.zeros((numRows, numCols), dtype=np.uint8) + mask[distance_from_center <= radius] = 1 + return mask + + def to_aria_image_frame( provider: VrsDataProvider, index: int, name_to_camera: Dict[str, AriaCameraCalibration], t_world_devices: TimedPoses, output_dir: Path, + camera_name: str = "camera-rgb", + pinhole: bool = False, ) -> AriaImageFrame: - name = "camera-rgb" + aria_cam_calib = name_to_camera[camera_name] + stream_id = provider.get_stream_id_from_label(camera_name) + assert stream_id is not None, f"Could not find stream {camera_name}" - camera_calibration = name_to_camera[name] - stream_id = provider.get_stream_id_from_label(name) - assert stream_id is not None, f"Could not find stream {name}" + # Retrieve the current camera calibration + device_calib = provider.get_device_calibration() + assert device_calib is not None, "Could not find device calibration" + src_calib = device_calib.get_camera_calib(camera_name) + assert isinstance(src_calib, calibration.CameraCalibration), "src_calib is not of type CameraCalibration" - # Get the image corresponding to this index + # Get the image corresponding to this index and undistort it image_data = provider.get_image_data_by_index(stream_id, index) - img = Image.fromarray(image_data[0].to_numpy_array()) + image_array, intrinsic = image_data[0].to_numpy_array().astype(np.uint8), (0, 0, 0, 0) + if pinhole: + f_length = 500 if camera_name == "camera-rgb" else 170 + image_array, src_calib = undistort_image_and_calibration(image_array, src_calib, f_length) + intrinsic = (f_length, f_length, image_array.shape[1] // 2, image_array.shape[0] // 2) + + # Rotate the image right side up + image_array, src_calib = rotate_upright_image_and_calibration(image_array, src_calib) + img = Image.fromarray(image_array) capture_time_ns = image_data[1].capture_timestamp_ns + intrinsic = (intrinsic[0], intrinsic[1], intrinsic[3], intrinsic[2]) - file_path = f"{output_dir}/{name}_{capture_time_ns}.jpg" + # Save the image + file_path = f"{output_dir}/{camera_name}_{capture_time_ns}.jpg" threading.Thread(target=lambda: img.save(file_path)).start() # Find the nearest neighbor pose with the closest timestamp to the capture time. @@ -146,17 +226,46 @@ def to_aria_image_frame( t_world_device = t_world_devices.t_world_devices[nearest_pose_idx] # Compute the world to camera transform. - t_world_camera = t_world_device @ camera_calibration.t_device_camera @ T_ARIA_NERFSTUDIO + t_world_camera = t_world_device @ src_calib.get_transform_device_camera() @ T_ARIA_NERFSTUDIO + + # Define new AriaCameraCalibration since we rotated the image + width = src_calib.get_image_size()[0].item() + height = src_calib.get_image_size()[1].item() + intrinsics = src_calib.projection_params() + aria_cam_calib = AriaCameraCalibration( + fx=intrinsics[0], + fy=intrinsics[0], + cx=intrinsics[1], + cy=intrinsics[2], + distortion_params=intrinsics[3:15], + width=width, + height=height, + t_device_camera=src_calib.get_transform_device_camera(), + ) return AriaImageFrame( - camera=camera_calibration, + camera=aria_cam_calib, file_path=file_path, t_world_camera=t_world_camera, timestamp_ns=capture_time_ns, + pinhole_intrinsic=intrinsic, ) -def to_nerfstudio_frame(frame: AriaImageFrame) -> Dict: +def to_nerfstudio_frame(frame: AriaImageFrame, pinhole: bool = False, mask_path: str = "") -> Dict: + if pinhole: + return { + "fl_x": frame.pinhole_intrinsic[0], + "fl_y": frame.pinhole_intrinsic[1], + "cx": frame.pinhole_intrinsic[2], + "cy": frame.pinhole_intrinsic[3], + "w": frame.pinhole_intrinsic[2] * 2, + "h": frame.pinhole_intrinsic[3] * 2, + "file_path": frame.file_path, + "transform_matrix": frame.t_world_camera.to_matrix().tolist(), + "timestamp": frame.timestamp_ns, + "mask_path": mask_path, + } return { "fl_x": frame.camera.fx, "fl_y": frame.camera.fy, @@ -178,70 +287,134 @@ class ProcessProjectAria: https://facebookresearch.github.io/projectaria_tools/docs/ARK/mps. """ - vrs_file: Path - """Path to the VRS file.""" - mps_data_dir: Path + vrs_file: Tuple[Path, ...] + """Path to the VRS file(s).""" + mps_data_dir: Tuple[Path, ...] """Path to Project Aria Machine Perception Services (MPS) attachments.""" output_dir: Path """Path to the output directory.""" + points_file: Optional[Tuple[Path, ...]] = () + """Path to the point cloud file (usually called semidense_points.csv.gz) if not in the mps_data_dir""" + include_side_cameras: bool = False + """If True, include and process the images captured by the grayscale side cameras. If False, only uses the main RGB camera's data.""" def main(self) -> None: """Generate a nerfstudio dataset from ProjectAria data (VRS) and MPS attachments.""" - # Create output directory if it doesn't exist. + # Create output directory if it doesn't exist self.output_dir = self.output_dir.absolute() self.output_dir.mkdir(parents=True, exist_ok=True) - provider = create_vrs_data_provider(str(self.vrs_file.absolute())) - assert provider is not None, "Cannot open file" - - name_to_camera = get_camera_calibs(provider) - - print("Getting poses from closed loop trajectory CSV...") - trajectory_csv = self.mps_data_dir / "closed_loop_trajectory.csv" - t_world_devices = read_trajectory_csv_to_dict(str(trajectory_csv.absolute())) - - name = "camera-rgb" - stream_id = provider.get_stream_id_from_label(name) - - # create an AriaImageFrame for each image in the VRS. - print("Creating Aria frames...") - aria_frames = [ - to_aria_image_frame(provider, index, name_to_camera, t_world_devices, self.output_dir) - for index in range(0, provider.get_num_data(stream_id)) - ] - - # create the NerfStudio frames from the AriaImageFrames. - print("Creating NerfStudio frames...") - CANONICAL_RGB_VALID_RADIUS = 707.5 - CANONICAL_RGB_WIDTH = 1408 - rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * (aria_frames[0].camera.width / CANONICAL_RGB_WIDTH) + # Create list of tuples containing files from each wearer and output variables + assert len(self.vrs_file) == len( + self.mps_data_dir + ), "Please provide an Aria MPS attachment for each corresponding VRS file." + vrs_mps_points_triplets = list(zip_longest(self.vrs_file, self.mps_data_dir, self.points_file)) # type: ignore nerfstudio_frames = { - "camera_model": ARIA_CAMERA_MODEL, - "frames": [to_nerfstudio_frame(frame) for frame in aria_frames], - "fisheye_crop_radius": rgb_valid_radius, + "camera_model": "OPENCV" if self.include_side_cameras else ARIA_CAMERA_MODEL, + "frames": [], } - - # save global point cloud, which is useful for Gaussian Splatting. - points_path = self.mps_data_dir / "global_points.csv.gz" - if not points_path.exists(): - # MPS point cloud output was renamed in Aria's December 4th, 2023 update. - # https://facebookresearch.github.io/projectaria_tools/docs/ARK/sw_release_notes#project-aria-updates-aria-mobile-app-v140-and-changes-to-mps - points_path = self.mps_data_dir / "semidense_points.csv.gz" - - if points_path.exists(): - print("Found global points, saving to PLY...") - points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore - points_data = filter_points_from_confidence(points_data) + points = [] + + # Process the aria data of each user one by one + for rec_i, (vrs_file, mps_data_dir, points_file) in enumerate(vrs_mps_points_triplets): + provider = create_vrs_data_provider(str(vrs_file.absolute())) + assert provider is not None, "Cannot open file" + + names = ["camera-rgb", "camera-slam-left", "camera-slam-right"] + name_to_camera = { + name: get_camera_calibs(provider, name) # type: ignore + for name in names + } # name_to_camera is of type Dict[str, AriaCameraCalibration] + + print(f"Getting poses from recording {rec_i + 1}'s closed loop trajectory CSV...") + trajectory_csv = mps_data_dir / "closed_loop_trajectory.csv" + t_world_devices = read_trajectory_csv_to_dict(str(trajectory_csv.absolute())) + + stream_ids = [provider.get_stream_id_from_label(name) for name in names] + + # Create an AriaImageFrame for each image in the VRS + print(f"Creating Aria frames for recording {rec_i + 1}...") + CANONICAL_RGB_VALID_RADIUS = 707.5 # radius of a circular mask that represents the valid area on the camera's sensor plane. Pixels out of this circular region are considered invalid + CANONICAL_RGB_WIDTH = 1408 + if not self.include_side_cameras: + aria_rgb_frames = [ + to_aria_image_frame( + provider, index, name_to_camera, t_world_devices, self.output_dir, camera_name=names[0] + ) + for index in range(0, provider.get_num_data(stream_ids[0])) + ] + print(f"Creating NerfStudio frames for recording {rec_i + 1}...") + nerfstudio_frames["frames"] += [to_nerfstudio_frame(frame) for frame in aria_rgb_frames] + rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * ( + aria_rgb_frames[0].camera.width / CANONICAL_RGB_WIDTH + ) # to handle both high-res 2880 x 2880 aria captures + nerfstudio_frames["fisheye_crop_radius"] = rgb_valid_radius + else: + aria_all3cameras_pinhole_frames = [ + [ + to_aria_image_frame( + provider, + index, + name_to_camera, + t_world_devices, + self.output_dir, + camera_name=names[i], + pinhole=True, + ) + for index in range(0, provider.get_num_data(stream_id)) + ] + for i, stream_id in enumerate(stream_ids) + ] + # Generate masks for undistorted images + rgb_width = aria_all3cameras_pinhole_frames[0][0].camera.width + rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * (rgb_width / CANONICAL_RGB_WIDTH) + slam_valid_radius = 330.0 # found here: https://github.com/facebookresearch/projectaria_tools/blob/4aee633cb667ab927825dc10477cad0df8393a34/core/calibration/loader/SensorCalibrationJson.cpp#L102C5-L104C18 + rgb_mask_nparray, slam_mask_nparray = ( + generate_circular_mask(rgb_width, rgb_width, rgb_valid_radius), + generate_circular_mask(640, 480, slam_valid_radius), + ) + rgb_mask_filepath, slam_mask_filepath = ( + f"{self.output_dir}/rgb_mask.jpg", + f"{self.output_dir}/slam_mask.jpg", + ) + Image.fromarray(rgb_mask_nparray).save(rgb_mask_filepath) + Image.fromarray(slam_mask_nparray).save(slam_mask_filepath) + + print(f"Creating NerfStudio frames for recording {rec_i + 1}...") + mask_filepaths = [rgb_mask_filepath, slam_mask_filepath, slam_mask_filepath] + pinhole_frames = [ + to_nerfstudio_frame(frame, pinhole=True, mask_path=mask_filepath) + for i, mask_filepath in enumerate(mask_filepaths) + for frame in aria_all3cameras_pinhole_frames[i] + ] + nerfstudio_frames["frames"] += pinhole_frames + + if points_file: + points_path = points_file + else: + points_path = mps_data_dir / "global_points.csv.gz" + if not points_path.exists(): + # MPS point cloud output was renamed in Aria's December 4th, 2023 update. + # https://facebookresearch.github.io/projectaria_tools/docs/ARK/sw_release_notes#project-aria-updates-aria-mobile-app-v140-and-changes-to-mps + points_path = mps_data_dir / "semidense_points.csv.gz" + + if points_path.exists(): + print(f"Found global points for recording {rec_i+1}") + points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore + points_data = filter_points_from_confidence(points_data) + points += [cast(Any, it).position_world for it in points_data] + + if points: + print("Saving found points to PLY...") pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(np.array([cast(Any, it).position_world for it in points_data])) + pcd.points = o3d.utility.Vector3dVector(np.array(points)) ply_file_path = self.output_dir / "global_points.ply" o3d.io.write_point_cloud(str(ply_file_path), pcd) - nerfstudio_frames["ply_file_path"] = "global_points.ply" else: print("No global points found!") - # write the json out to disk as transforms.json + # Write the json out to disk as transforms.json print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" with open(transform_file, "w", encoding="UTF-8"): From c6dde7d511a8fd2da626d5d37e23b4ea1c4148b0 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 04:21:28 -0700 Subject: [PATCH 02/78] pwais changes with RayBatchStream to alleviate training --- .../data/datamanagers/base_datamanager.py | 94 ++++++++--- nerfstudio/data/utils/dataloaders.py | 158 ++++++++++++++++++ 2 files changed, 233 insertions(+), 19 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 36e0574109..ae2d2623bc 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -484,19 +484,68 @@ def setup_train(self): """Sets up the data loaders for training""" assert self.train_dataset is not None CONSOLE.print("Setting up training dataset...") - self.train_image_dataloader = CacheDataloader( - self.train_dataset, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, - device=self.device, - num_workers=self.world_size * 4, - pin_memory=True, - collate_fn=self.config.collate_fn, - exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, - ) - self.iter_train_image_dataloader = iter(self.train_image_dataloader) - self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) - self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) + # self.train_image_dataloader = CacheDataloader( + # self.train_dataset, + # num_images_to_sample_from=self.config.train_num_images_to_sample_from, + # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + # device=self.device, + # num_workers=self.world_size * 4, + # pin_memory=True, + # collate_fn=self.config.collate_fn, + # exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, + # ) + # self.iter_train_image_dataloader = iter(self.train_image_dataloader) + # self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) + # self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) + + if self.config.use_ray_train_dataloader: + self.raybatch_stream = RayBatchStream( + self.train_dataset, + self.config, + # self.train_pixel_sampler, + # self.train_ray_generator, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + device=self.device, + num_workers=self.world_size * 4, + pin_memory=True, + # device=self.device, + ) + self.ray_dataloader = torch.utils.data.DataLoader( + self.raybatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + prefetch_factor=self.config.dataloader_prefetch_size, + shuffle=False, + pin_memory=False, + # Our dataset does batching / collation + collate_fn=lambda x: x, + # pin_memory_device=self.device + ) + self.iter_train_image_dataloader = None + self.iter_train_raybundles = iter(self.ray_dataloader) + else: + self.iter_train_raybundles = None + self.train_image_dataloader = CacheDataloader( + self.train_dataset, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + device=self.device, + num_workers= + self.world_size * 4 + if self.config.dataloader_num_workers == -1 + else self.config.dataloader_num_workers, + prefetch_factor= + 2 + if self.config.dataloader_prefetch_size == -1 + else self.config.dataloader_prefetch_size, + pin_memory=True, + collate_fn=self.config.collate_fn, + exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, + ) + self.iter_train_image_dataloader = iter(self.train_image_dataloader) + self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) + self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) def setup_eval(self): """Sets up the data loader for evaluation""" @@ -531,12 +580,19 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" breakpoint() self.train_count += 1 - image_batch = next(self.iter_train_image_dataloader) - assert self.train_pixel_sampler is not None - assert isinstance(image_batch, dict) - batch = self.train_pixel_sampler.sample(image_batch) - ray_indices = batch["indices"] - ray_bundle = self.train_ray_generator(ray_indices) + if self.config.use_ray_train_dataloader: + ret = next(self.iter_train_raybundles) + assert len(ret) == 1, f"batch size should be one {len(ret)}" + ray_bundle, batch = ret[0] + # ray_bundle = RayBundle.from_dict(ray_bundle_dict) + ray_bundle = ray_bundle.to(self.device) + else: + image_batch = next(self.iter_train_image_dataloader) + assert self.train_pixel_sampler is not None + assert isinstance(image_batch, dict) + batch = self.train_pixel_sampler.sample(image_batch) + ray_indices = batch["indices"] + ray_bundle = self.train_ray_generator(ray_indices) return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 6a64ba738b..f5fa88f94b 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -31,7 +31,9 @@ from nerfstudio.cameras.cameras import Cameras from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.datamanagers.base_datamanager import DataManagerConfig from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate +from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_dict_to_torch from nerfstudio.utils.rich_utils import CONSOLE @@ -146,6 +148,162 @@ def __iter__(self): yield collated_batch +class RayBatchStream(torch.utils.data.IterableDataset): + def __init__( + self, + input_dataset: Dataset, + datamanager_config : DataManagerConfig, + num_images_to_sample_from: int = -1, + device: Union[torch.device, str] = "cpu", + collate_fn: Callable[[Any], Any] = nerfstudio_collate, + exclude_batch_keys_from_device: Optional[List[str]] = None, + num_image_load_threads : int = 2, + cache_all_n_shard_per_worker : bool = True, + ): + if exclude_batch_keys_from_device is None: + exclude_batch_keys_from_device = ["image"] + self.input_dataset = input_dataset + assert isinstance(self.input_dataset, Sized) + + # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset + + # self.num_times_to_repeat_images = num_times_to_repeat_images + # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) + # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from + self.num_images_to_sample_from = num_images_to_sample_from + self.device = device + self.collate_fn = collate_fn + # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults + self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults + self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + + self.datamanager_config = datamanager_config + self.pixel_sampler = None + self.ray_generator = None + self._cached_collated_batch = None + self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker + + def _get_batch_list(self, indices=None): + """Returns a list of batches from the dataset attribute.""" + + assert isinstance(self.input_dataset, Sized) + if not indices: + indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) + # indices = range(len(self.input_dataset)) + batch_list = [] + results = [] + + # num_threads = int(self.num_ds_load_threads) * 4 + num_threads = ( + int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker + else 4 * int(self.num_image_load_threads)) + num_threads = min(num_threads, multiprocessing.cpu_count() - 1) + num_threads = max(num_threads, 1) + # print('num_threads', num_threads) + + # NB: this is I/O heavy, hence multi-threaded inside the worker + from tqdm.auto import tqdm + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for idx in indices: + res = executor.submit(self.input_dataset.__getitem__, idx) + results.append(res) + + # for res in track(results, description="Loading data batch", transient=True): + # for res in tqdm(results, desc='_get_batch_list'): + if self.cache_all_n_shard_per_worker: + results = tqdm(results) + for res in results: + batch_list.append(res.result()) + return batch_list + + def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: + """copy-pasta from VanillaDataManager.""" + from nerfstudio.cameras.cameras import Cameras, CameraType + from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + + if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: + return PatchPixelSamplerConfig().setup( + patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch + ) + is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + fisheye_crop_radius = None + if dataset.cameras.metadata is not None: + fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + return self.datamanager_config.pixel_sampler.setup( + is_equirectangular=is_equirectangular, + num_rays_per_batch=num_rays_per_batch, + fisheye_crop_radius=fisheye_crop_radius, + ) + + def _get_collated_batch(self, indices=None): + """Returns a collated batch.""" + batch_list = self._get_batch_list(indices=indices) + # print('running collate_fn', self.collate_fn) + collated_batch = self.collate_fn(batch_list) + # print('done collate_fn') + # assert False, (self.exclude_batch_keys_from_device, collated_batch) + collated_batch = get_dict_to_torch( + collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + ) + # print('done get_dict_to_torch') + # print('_get_collated_batch') + return collated_batch + + def __iter__(self): + # Set up stuff now that we're in the worker process + if self.cache_all_n_shard_per_worker: + this_indices = list(range(len(self.input_dataset))) + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + print('TODO log. only single worker not sharding!') + worker_id = -1 + else: + # assign this worker a deterministic uniformly sampled slice + # of the dataset + import math + per_worker = int( + math.ceil(len(this_indices) / float(worker_info.num_workers))) + r = random.Random(1337) + r.shuffle(this_indices) + worker_id = worker_info.id + slice_start = worker_id * per_worker + this_indices = this_indices[slice_start:slice_start+per_worker] + print( + f'Worker ID {worker_id} working on {len(this_indices)} indices') + + import time + start = time.time() + print(f"Worker ID {worker_id} caching collated batch ...") + self._cached_collated_batch = self._get_collated_batch( + indices=this_indices) + print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") + + if self.pixel_sampler is None: + self.pixel_sampler = self._get_pixel_sampler( + self.input_dataset, + self.datamanager_config.train_num_rays_per_batch) + if self.ray_generator is None: + self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + + # if self._cached_collated_batch is None: + # self._cached_collated_batch = self._get_collated_batch() + # print('did _cached_collated_batch') + while True: + if self._cached_collated_batch is None: + collated_batch = self._get_collated_batch() + else: + collated_batch = self._cached_collated_batch + # batch = self.pixel_sampler.sample(self._cached_collated_batch) + batch = self.pixel_sampler.sample(collated_batch) + ray_indices = batch["indices"] + ray_bundle = self.ray_generator(ray_indices) + yield ray_bundle, batch + + class EvalDataloader(DataLoader): """Evaluation dataloader base class From 78453cdcc0654748b7a07f98d335b695f73e983f Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 06:21:29 -0700 Subject: [PATCH 03/78] few bugs to iron out with multiprocessing, specifically pickled collate_fn --- .../data/datamanagers/base_datamanager.py | 20 ++++---- nerfstudio/data/utils/dataloaders.py | 47 +++++++++---------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index ae2d2623bc..6d4c687713 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -56,7 +56,7 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader, RayBatchStream from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator @@ -335,6 +335,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" + dataloader_prefetch_size : int = 2 + dataloader_num_workers : int = 16 + use_ray_train_dataloader: bool = True # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -499,17 +502,19 @@ def setup_train(self): # self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) if self.config.use_ray_train_dataloader: + import torch.multiprocessing as mp + mp.set_start_method('spawn') self.raybatch_stream = RayBatchStream( - self.train_dataset, - self.config, + input_dataset=self.train_dataset, + #self.config, # self.train_pixel_sampler, # self.train_ray_generator, num_images_to_sample_from=self.config.train_num_images_to_sample_from, - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # no work device=self.device, - num_workers=self.world_size * 4, - pin_memory=True, - # device=self.device, + collate_fn=self.config.collate_fn, + # num_workers=self.world_size * 4,# this is part of torch.utils.data.DataLoader + # pin_memory=True, # this is part of torch.utils.data.DataLoader ) self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, @@ -578,7 +583,6 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" - breakpoint() self.train_count += 1 if self.config.use_ray_train_dataloader: ret = next(self.iter_train_raybundles) diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index f5fa88f94b..c8f60e24df 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -31,7 +31,6 @@ from nerfstudio.cameras.cameras import Cameras from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.datamanagers.base_datamanager import DataManagerConfig from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_dict_to_torch @@ -152,7 +151,6 @@ class RayBatchStream(torch.utils.data.IterableDataset): def __init__( self, input_dataset: Dataset, - datamanager_config : DataManagerConfig, num_images_to_sample_from: int = -1, device: Union[torch.device, str] = "cpu", collate_fn: Callable[[Any], Any] = nerfstudio_collate, @@ -177,7 +175,6 @@ def __init__( self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - self.datamanager_config = datamanager_config self.pixel_sampler = None self.ray_generator = None self._cached_collated_batch = None @@ -216,28 +213,28 @@ def _get_batch_list(self, indices=None): batch_list.append(res.result()) return batch_list - def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: - """copy-pasta from VanillaDataManager.""" - from nerfstudio.cameras.cameras import Cameras, CameraType - from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig - - if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch - ) - is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - if is_equirectangular.any(): - CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - fisheye_crop_radius = None - if dataset.cameras.metadata is not None: - fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - return self.datamanager_config.pixel_sampler.setup( - is_equirectangular=is_equirectangular, - num_rays_per_batch=num_rays_per_batch, - fisheye_crop_radius=fisheye_crop_radius, - ) + # def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: + # """copy-pasta from VanillaDataManager.""" + # from nerfstudio.cameras.cameras import Cameras, CameraType + # from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + + # if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: + # return PatchPixelSamplerConfig().setup( + # patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch + # ) + # is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + # if is_equirectangular.any(): + # CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + # fisheye_crop_radius = None + # if dataset.cameras.metadata is not None: + # fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + # return self.datamanager_config.pixel_sampler.setup( + # is_equirectangular=is_equirectangular, + # num_rays_per_batch=num_rays_per_batch, + # fisheye_crop_radius=fisheye_crop_radius, + # ) def _get_collated_batch(self, indices=None): """Returns a collated batch.""" From f2bd96fca1e6c5c7aaa3b7aa23543ec72cb79351 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 17:22:17 -0700 Subject: [PATCH 04/78] working version of RayBatchStream --- .../data/datamanagers/base_datamanager.py | 169 +++++++++- nerfstudio/data/utils/dataloaders.py | 293 +++++++++--------- 2 files changed, 311 insertions(+), 151 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 6d4c687713..9a6486245b 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -56,7 +56,7 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader, RayBatchStream +from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader#, RayBatchStream from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator @@ -358,6 +358,169 @@ def __post_init__(self): TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) +import multiprocessing +from torch.utils.data import Dataset +from typing import Sized +import random +import concurrent.futures +from nerfstudio.utils.misc import get_dict_to_torch +class RayBatchStream(torch.utils.data.IterableDataset): + def __init__( + self, + input_dataset: Dataset, + datamanager_config : DataManagerConfig = None, + num_images_to_sample_from: int = -1, + device: Union[torch.device, str] = "cpu", + collate_fn: Callable[[Any], Any] = nerfstudio_collate, + exclude_batch_keys_from_device: Optional[List[str]] = None, + num_image_load_threads : int = 2, + cache_all_n_shard_per_worker : bool = True, + ): + if exclude_batch_keys_from_device is None: + exclude_batch_keys_from_device = ["image"] + self.input_dataset = input_dataset + assert isinstance(self.input_dataset, Sized) + + # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset + + # self.num_times_to_repeat_images = num_times_to_repeat_images + # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) + # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from + self.num_images_to_sample_from = num_images_to_sample_from + self.device = device + self.collate_fn = collate_fn + # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults + self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults + self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + + self.datamanager_config = datamanager_config + self.pixel_sampler = None + self.ray_generator = None + self._cached_collated_batch = None + self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker + + def _get_batch_list(self, indices=None): + """Returns a list of batches from the dataset attribute.""" + + assert isinstance(self.input_dataset, Sized) + if indices is None: + indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) + # indices = range(len(self.input_dataset)) + batch_list = [] + results = [] + + # num_threads = int(self.num_ds_load_threads) * 4 + num_threads = ( + int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker + else 4 * int(self.num_image_load_threads) + ) + num_threads = min(num_threads, multiprocessing.cpu_count() - 1) + num_threads = max(num_threads, 1) + # print('num_threads', num_threads) + + # NB: this is I/O heavy, hence multi-threaded inside the worker + from tqdm.auto import tqdm + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for idx in indices: + res = executor.submit(self.input_dataset.__getitem__, idx) + results.append(res) + + # for res in track(results, description="Loading data batch", transient=True): + # for res in tqdm(results, desc='_get_batch_list'): + if self.cache_all_n_shard_per_worker: + results = tqdm(results) + for res in results: + breakpoint() + batch_list.append(res.result()) + return batch_list + + def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: + """copy-pasta from VanillaDataManager.""" + from nerfstudio.cameras.cameras import Cameras, CameraType + from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + + if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: + return PatchPixelSamplerConfig().setup( + patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch + ) + is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + fisheye_crop_radius = None + if dataset.cameras.metadata is not None: + fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + return self.datamanager_config.pixel_sampler.setup( + is_equirectangular=is_equirectangular, + num_rays_per_batch=num_rays_per_batch, + fisheye_crop_radius=fisheye_crop_radius, + ) + + def _get_collated_batch(self, indices=None): + """Returns a collated batch.""" + batch_list = self._get_batch_list(indices=indices) + # print('running collate_fn', self.collate_fn) + collated_batch = self.collate_fn(batch_list) + # print('done collate_fn') + # assert False, (self.exclude_batch_keys_from_device, collated_batch) + collated_batch = get_dict_to_torch( + collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + ) + # print('done get_dict_to_torch') + # print('_get_collated_batch') + return collated_batch + + def __iter__(self): + # Set up stuff now that we're in the worker process + if self.cache_all_n_shard_per_worker: + this_indices = list(range(len(self.input_dataset))) + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + print('TODO log. only single worker not sharding!') + worker_id = -1 + else: + # assign this worker a deterministic uniformly sampled slice + # of the dataset + import math + per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) + r = random.Random(1337) + r.shuffle(this_indices) + worker_id = worker_info.id + slice_start = worker_id * per_worker + this_indices = this_indices[slice_start:slice_start+per_worker] + print(f'Worker ID {worker_id} working on {len(this_indices)} indices') + + import time + start = time.time() + print(f"Worker ID {worker_id} caching collated batch ...") + self._cached_collated_batch = self._get_collated_batch(indices=this_indices) + print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") + + if self.pixel_sampler is None: + self.pixel_sampler = self._get_pixel_sampler( + self.input_dataset, + self.datamanager_config.train_num_rays_per_batch) + if self.ray_generator is None: + self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + + # if self._cached_collated_batch is None: + # self._cached_collated_batch = self._get_collated_batch() + # print('did _cached_collated_batch') + while True: + if self._cached_collated_batch is None: + collated_batch = self._get_collated_batch() + else: + collated_batch = self._cached_collated_batch + # batch = self.pixel_sampler.sample(self._cached_collated_batch) + batch = self.pixel_sampler.sample(collated_batch) + ray_indices = batch["indices"] + ray_bundle = self.ray_generator(ray_indices) + yield ray_bundle, batch + + +def identity(x): + return x class VanillaDataManager(DataManager, Generic[TDataset]): """Basic stored data manager implementation. @@ -506,7 +669,7 @@ def setup_train(self): mp.set_start_method('spawn') self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, - #self.config, + datamanager_config=self.config, # self.train_pixel_sampler, # self.train_ray_generator, num_images_to_sample_from=self.config.train_num_images_to_sample_from, @@ -524,7 +687,7 @@ def setup_train(self): shuffle=False, pin_memory=False, # Our dataset does batching / collation - collate_fn=lambda x: x, + collate_fn=identity, # pin_memory_device=self.device ) self.iter_train_image_dataloader = None diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index c8f60e24df..3a1d744f1d 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -147,158 +147,155 @@ def __iter__(self): yield collated_batch -class RayBatchStream(torch.utils.data.IterableDataset): - def __init__( - self, - input_dataset: Dataset, - num_images_to_sample_from: int = -1, - device: Union[torch.device, str] = "cpu", - collate_fn: Callable[[Any], Any] = nerfstudio_collate, - exclude_batch_keys_from_device: Optional[List[str]] = None, - num_image_load_threads : int = 2, - cache_all_n_shard_per_worker : bool = True, - ): - if exclude_batch_keys_from_device is None: - exclude_batch_keys_from_device = ["image"] - self.input_dataset = input_dataset - assert isinstance(self.input_dataset, Sized) - - # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset - - # self.num_times_to_repeat_images = num_times_to_repeat_images - # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) - # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from - self.num_images_to_sample_from = num_images_to_sample_from - self.device = device - self.collate_fn = collate_fn - # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults - self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults - self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - - self.pixel_sampler = None - self.ray_generator = None - self._cached_collated_batch = None - self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker +# class RayBatchStream(torch.utils.data.IterableDataset): +# def __init__( +# self, +# input_dataset: Dataset, +# num_images_to_sample_from: int = -1, +# device: Union[torch.device, str] = "cpu", +# collate_fn: Callable[[Any], Any] = nerfstudio_collate, +# exclude_batch_keys_from_device: Optional[List[str]] = None, +# num_image_load_threads : int = 2, +# cache_all_n_shard_per_worker : bool = True, +# ): +# if exclude_batch_keys_from_device is None: +# exclude_batch_keys_from_device = ["image"] +# self.input_dataset = input_dataset +# assert isinstance(self.input_dataset, Sized) + +# # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset + +# # self.num_times_to_repeat_images = num_times_to_repeat_images +# # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) +# # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from +# self.num_images_to_sample_from = num_images_to_sample_from +# self.device = device +# self.collate_fn = collate_fn +# # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults +# self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults +# self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + +# self.pixel_sampler = None +# self.ray_generator = None +# self._cached_collated_batch = None +# self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker - def _get_batch_list(self, indices=None): - """Returns a list of batches from the dataset attribute.""" - - assert isinstance(self.input_dataset, Sized) - if not indices: - indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) - # indices = range(len(self.input_dataset)) - batch_list = [] - results = [] - - # num_threads = int(self.num_ds_load_threads) * 4 - num_threads = ( - int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker - else 4 * int(self.num_image_load_threads)) - num_threads = min(num_threads, multiprocessing.cpu_count() - 1) - num_threads = max(num_threads, 1) - # print('num_threads', num_threads) - - # NB: this is I/O heavy, hence multi-threaded inside the worker - from tqdm.auto import tqdm - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for idx in indices: - res = executor.submit(self.input_dataset.__getitem__, idx) - results.append(res) - - # for res in track(results, description="Loading data batch", transient=True): - # for res in tqdm(results, desc='_get_batch_list'): - if self.cache_all_n_shard_per_worker: - results = tqdm(results) - for res in results: - batch_list.append(res.result()) - return batch_list +# def _get_batch_list(self, indices=None): +# """Returns a list of batches from the dataset attribute.""" + +# assert isinstance(self.input_dataset, Sized) +# if not indices: +# indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) +# # indices = range(len(self.input_dataset)) +# batch_list = [] +# results = [] + +# # num_threads = int(self.num_ds_load_threads) * 4 +# num_threads = ( +# int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker +# else 4 * int(self.num_image_load_threads)) +# num_threads = min(num_threads, multiprocessing.cpu_count() - 1) +# num_threads = max(num_threads, 1) +# # print('num_threads', num_threads) + +# # NB: this is I/O heavy, hence multi-threaded inside the worker +# from tqdm.auto import tqdm +# with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: +# for idx in indices: +# res = executor.submit(self.input_dataset.__getitem__, idx) +# results.append(res) + +# # for res in track(results, description="Loading data batch", transient=True): +# # for res in tqdm(results, desc='_get_batch_list'): +# if self.cache_all_n_shard_per_worker: +# results = tqdm(results) +# for res in results: +# batch_list.append(res.result()) +# return batch_list - # def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: - # """copy-pasta from VanillaDataManager.""" - # from nerfstudio.cameras.cameras import Cameras, CameraType - # from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig - - # if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: - # return PatchPixelSamplerConfig().setup( - # patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch - # ) - # is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - # if is_equirectangular.any(): - # CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - # fisheye_crop_radius = None - # if dataset.cameras.metadata is not None: - # fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - # return self.datamanager_config.pixel_sampler.setup( - # is_equirectangular=is_equirectangular, - # num_rays_per_batch=num_rays_per_batch, - # fisheye_crop_radius=fisheye_crop_radius, - # ) +# # def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: +# # """copy-pasta from VanillaDataManager.""" +# # from nerfstudio.cameras.cameras import Cameras, CameraType +# # from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + +# # if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: +# # return PatchPixelSamplerConfig().setup( +# # patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch +# # ) +# # is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() +# # if is_equirectangular.any(): +# # CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + +# # fisheye_crop_radius = None +# # if dataset.cameras.metadata is not None: +# # fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + +# # return self.datamanager_config.pixel_sampler.setup( +# # is_equirectangular=is_equirectangular, +# # num_rays_per_batch=num_rays_per_batch, +# # fisheye_crop_radius=fisheye_crop_radius, +# # ) - def _get_collated_batch(self, indices=None): - """Returns a collated batch.""" - batch_list = self._get_batch_list(indices=indices) - # print('running collate_fn', self.collate_fn) - collated_batch = self.collate_fn(batch_list) - # print('done collate_fn') - # assert False, (self.exclude_batch_keys_from_device, collated_batch) - collated_batch = get_dict_to_torch( - collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device - ) - # print('done get_dict_to_torch') - # print('_get_collated_batch') - return collated_batch - - def __iter__(self): - # Set up stuff now that we're in the worker process - if self.cache_all_n_shard_per_worker: - this_indices = list(range(len(self.input_dataset))) - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - print('TODO log. only single worker not sharding!') - worker_id = -1 - else: - # assign this worker a deterministic uniformly sampled slice - # of the dataset - import math - per_worker = int( - math.ceil(len(this_indices) / float(worker_info.num_workers))) - r = random.Random(1337) - r.shuffle(this_indices) - worker_id = worker_info.id - slice_start = worker_id * per_worker - this_indices = this_indices[slice_start:slice_start+per_worker] - print( - f'Worker ID {worker_id} working on {len(this_indices)} indices') +# def _get_collated_batch(self, indices=None): +# """Returns a collated batch.""" +# batch_list = self._get_batch_list(indices=indices) +# # print('running collate_fn', self.collate_fn) +# collated_batch = self.collate_fn(batch_list) +# # print('done collate_fn') +# # assert False, (self.exclude_batch_keys_from_device, collated_batch) +# collated_batch = get_dict_to_torch( +# collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device +# ) +# # print('done get_dict_to_torch') +# # print('_get_collated_batch') +# return collated_batch + +# def __iter__(self): +# # Set up stuff now that we're in the worker process +# if self.cache_all_n_shard_per_worker: +# this_indices = list(range(len(self.input_dataset))) +# worker_info = torch.utils.data.get_worker_info() +# if worker_info is None: +# print('TODO log. only single worker not sharding!') +# worker_id = -1 +# else: +# # assign this worker a deterministic uniformly sampled slice +# # of the dataset +# import math +# per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) +# r = random.Random(1337) +# r.shuffle(this_indices) +# worker_id = worker_info.id +# slice_start = worker_id * per_worker +# this_indices = this_indices[slice_start:slice_start+per_worker] +# print(f'Worker ID {worker_id} working on {len(this_indices)} indices') - import time - start = time.time() - print(f"Worker ID {worker_id} caching collated batch ...") - self._cached_collated_batch = self._get_collated_batch( - indices=this_indices) - print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") - - if self.pixel_sampler is None: - self.pixel_sampler = self._get_pixel_sampler( - self.input_dataset, - self.datamanager_config.train_num_rays_per_batch) - if self.ray_generator is None: - self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - - # if self._cached_collated_batch is None: - # self._cached_collated_batch = self._get_collated_batch() - # print('did _cached_collated_batch') - while True: - if self._cached_collated_batch is None: - collated_batch = self._get_collated_batch() - else: - collated_batch = self._cached_collated_batch - # batch = self.pixel_sampler.sample(self._cached_collated_batch) - batch = self.pixel_sampler.sample(collated_batch) - ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) - yield ray_bundle, batch +# import time +# start = time.time() +# print(f"Worker ID {worker_id} caching collated batch ...") +# self._cached_collated_batch = self._get_collated_batch(indices=this_indices) +# print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") + +# if self.pixel_sampler is None: +# self.pixel_sampler = self._get_pixel_sampler( +# self.input_dataset, +# self.datamanager_config.train_num_rays_per_batch) +# if self.ray_generator is None: +# self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + +# # if self._cached_collated_batch is None: +# # self._cached_collated_batch = self._get_collated_batch() +# # print('did _cached_collated_batch') +# while True: +# if self._cached_collated_batch is None: +# collated_batch = self._get_collated_batch() +# else: +# collated_batch = self._cached_collated_batch +# # batch = self.pixel_sampler.sample(self._cached_collated_batch) +# batch = self.pixel_sampler.sample(collated_batch) +# ray_indices = batch["indices"] +# ray_bundle = self.ray_generator(ray_indices) +# yield ray_bundle, batch class EvalDataloader(DataLoader): From d8b7430be482638c53ffb25c7fd99b777a9b373b Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 18:24:41 -0700 Subject: [PATCH 05/78] additional docstrings --- .../data/datamanagers/base_datamanager.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 9a6486245b..d2275b70d4 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -336,8 +336,14 @@ class VanillaDataManagerConfig(DataManagerConfig): patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" dataloader_prefetch_size : int = 2 + """The limit number of batches a worker will start loading once an iterator is created. + Each next() call on the iterator has the CPU prepare more batches up to this + limit while the GPU is performing forward and backward passes on the model.""" dataloader_num_workers : int = 16 + """The number of workers performing the dataloading from either disk/RAM, which + includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True + """Allows parallelization of the dataloading process with multiple workers.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -381,7 +387,6 @@ def __init__( self.input_dataset = input_dataset assert isinstance(self.input_dataset, Sized) - # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset # self.num_times_to_repeat_images = num_times_to_repeat_images # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) @@ -397,6 +402,7 @@ def __init__( self.pixel_sampler = None self.ray_generator = None self._cached_collated_batch = None + """""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker def _get_batch_list(self, indices=None): @@ -418,7 +424,8 @@ def _get_batch_list(self, indices=None): num_threads = max(num_threads, 1) # print('num_threads', num_threads) - # NB: this is I/O heavy, hence multi-threaded inside the worker + # NB: this is I/O heavy because we are going to disk and reading an image filename + # hence multi-threaded inside the worker from tqdm.auto import tqdm with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: for idx in indices: @@ -430,8 +437,8 @@ def _get_batch_list(self, indices=None): if self.cache_all_n_shard_per_worker: results = tqdm(results) for res in results: - breakpoint() batch_list.append(res.result()) + print(batch_list) return batch_list def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: @@ -650,19 +657,6 @@ def setup_train(self): """Sets up the data loaders for training""" assert self.train_dataset is not None CONSOLE.print("Setting up training dataset...") - # self.train_image_dataloader = CacheDataloader( - # self.train_dataset, - # num_images_to_sample_from=self.config.train_num_images_to_sample_from, - # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, - # device=self.device, - # num_workers=self.world_size * 4, - # pin_memory=True, - # collate_fn=self.config.collate_fn, - # exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, - # ) - # self.iter_train_image_dataloader = iter(self.train_image_dataloader) - # self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) - # self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) if self.config.use_ray_train_dataloader: import torch.multiprocessing as mp From a5425d40fdd935b7ad9ac0954fe7959dc242f917 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 18:37:49 -0700 Subject: [PATCH 06/78] cleanup --- .../data/datamanagers/base_datamanager.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d2275b70d4..d5093e804e 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -387,8 +387,6 @@ def __init__( self.input_dataset = input_dataset assert isinstance(self.input_dataset, Sized) - - # self.num_times_to_repeat_images = num_times_to_repeat_images # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from self.num_images_to_sample_from = num_images_to_sample_from @@ -405,13 +403,35 @@ def __init__( """""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker + def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: + """copy-pasta from VanillaDataManager.""" + from nerfstudio.cameras.cameras import Cameras, CameraType + from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + + if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: + return PatchPixelSamplerConfig().setup( + patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch + ) + is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + fisheye_crop_radius = None + if dataset.cameras.metadata is not None: + fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + return self.datamanager_config.pixel_sampler.setup( + is_equirectangular=is_equirectangular, + num_rays_per_batch=num_rays_per_batch, + fisheye_crop_radius=fisheye_crop_radius, + ) + def _get_batch_list(self, indices=None): """Returns a list of batches from the dataset attribute.""" assert isinstance(self.input_dataset, Sized) if indices is None: indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) - # indices = range(len(self.input_dataset)) batch_list = [] results = [] @@ -438,32 +458,8 @@ def _get_batch_list(self, indices=None): results = tqdm(results) for res in results: batch_list.append(res.result()) - print(batch_list) return batch_list - def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: - """copy-pasta from VanillaDataManager.""" - from nerfstudio.cameras.cameras import Cameras, CameraType - from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig - - if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch - ) - is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - if is_equirectangular.any(): - CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - fisheye_crop_radius = None - if dataset.cameras.metadata is not None: - fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - return self.datamanager_config.pixel_sampler.setup( - is_equirectangular=is_equirectangular, - num_rays_per_batch=num_rays_per_batch, - fisheye_crop_radius=fisheye_crop_radius, - ) - def _get_collated_batch(self, indices=None): """Returns a collated batch.""" batch_list = self._get_batch_list(indices=indices) From 604f7341ceff6d963c1e80a33556cb66a87d36b0 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 12 Jun 2024 21:41:56 -0700 Subject: [PATCH 07/78] much more documentation --- .../data/datamanagers/base_datamanager.py | 49 ++++++++++++------- .../datamanagers/full_images_datamanager.py | 2 +- .../scripts/datasets/process_project_aria.py | 1 + 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d5093e804e..e6974040d7 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -339,7 +339,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" - dataloader_num_workers : int = 16 + dataloader_num_workers : int = 2 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True @@ -400,8 +400,9 @@ def __init__( self.pixel_sampler = None self.ray_generator = None self._cached_collated_batch = None - """""" + """_cached_collated_batch contains a collated batch of images that's ready for pixel sampling. I""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker + """If True, _cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: """copy-pasta from VanillaDataManager.""" @@ -427,7 +428,10 @@ def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> Pi ) def _get_batch_list(self, indices=None): - """Returns a list of batches from the dataset attribute.""" + """Returns a list representing a single batch from the dataset attribute. + Each item of the list is a dictionary with dict_keys(['image_idx', 'image']) representing 1 image. + This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch + The length of the list is equal to the (# of training images) / (num_workers)""" assert isinstance(self.input_dataset, Sized) if indices is None: @@ -461,28 +465,38 @@ def _get_batch_list(self, indices=None): return batch_list def _get_collated_batch(self, indices=None): - """Returns a collated batch.""" + """Takes the output of _get_batch_list and collates them with nerfstudio_collate() + Note: dict is an instance of collections.abc.Mapping + + The resulting output is collated_batch: a dictionary with dict_keys(['image_idx', 'image']) + collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) + collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) + """ batch_list = self._get_batch_list(indices=indices) - # print('running collate_fn', self.collate_fn) + # print(type(batch_list[0])) # prints + # print(self.collate_fn) # prints nerfstudio_collate collated_batch = self.collate_fn(batch_list) - # print('done collate_fn') - # assert False, (self.exclude_batch_keys_from_device, collated_batch) + # collated_batch is a dictionary with dict_keys(['image_idx', 'image']) + # collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) + # collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) + #print(collated_batch['image_idx'].shape) + #print(collated_batch['image'].shape) collated_batch = get_dict_to_torch( collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device ) - # print('done get_dict_to_torch') - # print('_get_collated_batch') return collated_batch def __iter__(self): # Set up stuff now that we're in the worker process if self.cache_all_n_shard_per_worker: - this_indices = list(range(len(self.input_dataset))) + this_indices = list(range(len(self.input_dataset))) + # this_indices has len 300, at first it is the whole training dataset, but it gets partitioned into equal chunks worker_info = torch.utils.data.get_worker_info() if worker_info is None: print('TODO log. only single worker not sharding!') worker_id = -1 else: + # Here, we are in the worker process now # assign this worker a deterministic uniformly sampled slice # of the dataset import math @@ -491,7 +505,7 @@ def __iter__(self): r.shuffle(this_indices) worker_id = worker_info.id slice_start = worker_id * per_worker - this_indices = this_indices[slice_start:slice_start+per_worker] + this_indices = this_indices[slice_start:slice_start+per_worker] print(f'Worker ID {worker_id} working on {len(this_indices)} indices') import time @@ -501,21 +515,18 @@ def __iter__(self): print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") if self.pixel_sampler is None: - self.pixel_sampler = self._get_pixel_sampler( - self.input_dataset, - self.datamanager_config.train_num_rays_per_batch) + self.pixel_sampler = self._get_pixel_sampler( + self.input_dataset, + self.datamanager_config.train_num_rays_per_batch + ) if self.ray_generator is None: - self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - # if self._cached_collated_batch is None: - # self._cached_collated_batch = self._get_collated_batch() - # print('did _cached_collated_batch') while True: if self._cached_collated_batch is None: collated_batch = self._get_collated_batch() else: collated_batch = self._cached_collated_batch - # batch = self.pixel_sampler.sample(self._cached_collated_batch) batch = self.pixel_sampler.sample(collated_batch) ray_indices = batch["indices"] ray_bundle = self.ray_generator(ray_indices) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index a67d492992..6817ebe52a 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -79,7 +79,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): fps_reset_every: int = 100 """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every samples from the pool of all training cameras without replacement before a new round of sampling starts.""" - + class FullImageDatamanager(DataManager, Generic[TDataset]): """ diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index 4ebda7e4e8..a6df1f3396 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -415,6 +415,7 @@ def main(self) -> None: print("No global points found!") # Write the json out to disk as transforms.json + print(len(nerfstudio_frames['frames'])) print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" with open(transform_file, "w", encoding="UTF-8"): From 0143803778e5c8a25f64e00031f2653ce47bf985 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 13 Jun 2024 05:04:10 -0700 Subject: [PATCH 08/78] successfully trained AEA-script2_seq2 closed_loop without OOM --- .../data/datamanagers/base_datamanager.py | 26 +++++++++---------- .../datamanagers/full_images_datamanager.py | 8 ++++++ nerfstudio/data/utils/nerfstudio_collate.py | 2 ++ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index e6974040d7..d2394fb710 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -313,7 +313,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" - train_num_images_to_sample_from: int = -1 + train_num_images_to_sample_from: int = -1 # was -1 """Number of images to sample during training iteration.""" train_num_times_to_repeat_images: int = -1 """When not training on all images, number of iterations before picking new @@ -339,7 +339,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" - dataloader_num_workers : int = 2 + dataloader_num_workers : int = 8 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True @@ -365,6 +365,7 @@ def __post_init__(self): TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) import multiprocessing +import math from torch.utils.data import Dataset from typing import Sized import random @@ -380,7 +381,7 @@ def __init__( collate_fn: Callable[[Any], Any] = nerfstudio_collate, exclude_batch_keys_from_device: Optional[List[str]] = None, num_image_load_threads : int = 2, - cache_all_n_shard_per_worker : bool = True, + cache_all_n_shard_per_worker : bool = True, # When False, always getting Killed/bugs for some reason ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -476,11 +477,6 @@ def _get_collated_batch(self, indices=None): # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate collated_batch = self.collate_fn(batch_list) - # collated_batch is a dictionary with dict_keys(['image_idx', 'image']) - # collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) - # collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) - #print(collated_batch['image_idx'].shape) - #print(collated_batch['image'].shape) collated_batch = get_dict_to_torch( collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device ) @@ -488,18 +484,16 @@ def _get_collated_batch(self, indices=None): def __iter__(self): # Set up stuff now that we're in the worker process + this_indices = list(range(len(self.input_dataset))) + worker_info = torch.utils.data.get_worker_info() if self.cache_all_n_shard_per_worker: - this_indices = list(range(len(self.input_dataset))) # this_indices has len 300, at first it is the whole training dataset, but it gets partitioned into equal chunks - worker_info = torch.utils.data.get_worker_info() if worker_info is None: print('TODO log. only single worker not sharding!') worker_id = -1 else: - # Here, we are in the worker process now # assign this worker a deterministic uniformly sampled slice # of the dataset - import math per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) r = random.Random(1337) r.shuffle(this_indices) @@ -524,7 +518,13 @@ def __iter__(self): while True: if self._cached_collated_batch is None: - collated_batch = self._get_collated_batch() + per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) + r = random.Random(1337) + r.shuffle(this_indices) + worker_id = worker_info.id + slice_start = worker_id * per_worker + this_indices = this_indices[slice_start:slice_start+per_worker] + collated_batch = self._get_collated_batch(this_indices) else: collated_batch = self._cached_collated_batch batch = self.pixel_sampler.sample(collated_batch) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 6817ebe52a..0ce84e4b77 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -47,6 +47,14 @@ from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE +class ImageBatchStream(torch.utils.data.IterableDataset): + def __init__( + self, + + ): + return + + # def @dataclass class FullImageDatamanagerConfig(DataManagerConfig): diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index 8c8a633fb8..9a859a7fec 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -98,6 +98,8 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum(x.numel() for x in batch) + import warnings + warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') storage = elem.storage()._new_shared(numel, device=elem.device) out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out) From d3527e25725dcd77604d217efe349d023cc71963 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 13 Jun 2024 05:11:23 -0700 Subject: [PATCH 09/78] porting over aria dataset-size feature --- .../scripts/datasets/process_project_aria.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index a6df1f3396..54998a6564 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -13,12 +13,13 @@ # limitations under the License. import json +import random import sys import threading from dataclasses import dataclass from itertools import zip_longest from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, cast +from typing import Any, Dict, List, Literal, Tuple, cast import numpy as np import open3d as o3d @@ -228,7 +229,7 @@ def to_aria_image_frame( # Compute the world to camera transform. t_world_camera = t_world_device @ src_calib.get_transform_device_camera() @ T_ARIA_NERFSTUDIO - # Define new AriaCameraCalibration since we rotated the image + # Define new AriaCameraCalibration since we rotated the image to be upright width = src_calib.get_image_size()[0].item() height = src_calib.get_image_size()[1].item() intrinsics = src_calib.projection_params() @@ -293,10 +294,14 @@ class ProcessProjectAria: """Path to Project Aria Machine Perception Services (MPS) attachments.""" output_dir: Path """Path to the output directory.""" - points_file: Optional[Tuple[Path, ...]] = () + points_file: Tuple[Path, ...] = () """Path to the point cloud file (usually called semidense_points.csv.gz) if not in the mps_data_dir""" include_side_cameras: bool = False - """If True, include and process the images captured by the grayscale side cameras. If False, only uses the main RGB camera's data.""" + """If True, include and process the images captured by the grayscale side cameras. + If False, only uses the main RGB camera's data.""" + max_dataset_size: int = 600 + """Max number of images to train on. If the provided vrs_file has more images than max_dataset_size, + images will be sampled approximately evenly. If max_dataset_size=-1, use all images available.""" def main(self) -> None: """Generate a nerfstudio dataset from ProjectAria data (VRS) and MPS attachments.""" @@ -308,7 +313,8 @@ def main(self) -> None: assert len(self.vrs_file) == len( self.mps_data_dir ), "Please provide an Aria MPS attachment for each corresponding VRS file." - vrs_mps_points_triplets = list(zip_longest(self.vrs_file, self.mps_data_dir, self.points_file)) # type: ignore + vrs_mps_points_triplets = list(zip_longest(self.vrs_file, self.mps_data_dir, self.points_file)) # type: ignore + num_recordings = len(vrs_mps_points_triplets) nerfstudio_frames = { "camera_model": "OPENCV" if self.include_side_cameras else ARIA_CAMERA_MODEL, "frames": [], @@ -336,12 +342,22 @@ def main(self) -> None: print(f"Creating Aria frames for recording {rec_i + 1}...") CANONICAL_RGB_VALID_RADIUS = 707.5 # radius of a circular mask that represents the valid area on the camera's sensor plane. Pixels out of this circular region are considered invalid CANONICAL_RGB_WIDTH = 1408 + total_num_images_per_camera = provider.get_num_data(stream_ids[0]) + if self.max_dataset_size == -1: + num_images_to_sample_per_camera = total_num_images_per_camera + else: + num_images_to_sample_per_camera = ( + self.max_dataset_size // (len(vrs_mps_points_triplets) * 3) + if self.include_side_cameras + else self.max_dataset_size // len(vrs_mps_points_triplets) + ) + sampling_indicies = random.sample(range(total_num_images_per_camera), num_images_to_sample_per_camera) if not self.include_side_cameras: aria_rgb_frames = [ to_aria_image_frame( provider, index, name_to_camera, t_world_devices, self.output_dir, camera_name=names[0] ) - for index in range(0, provider.get_num_data(stream_ids[0])) + for index in sampling_indicies ] print(f"Creating NerfStudio frames for recording {rec_i + 1}...") nerfstudio_frames["frames"] += [to_nerfstudio_frame(frame) for frame in aria_rgb_frames] @@ -361,7 +377,7 @@ def main(self) -> None: camera_name=names[i], pinhole=True, ) - for index in range(0, provider.get_num_data(stream_id)) + for index in sampling_indicies ] for i, stream_id in enumerate(stream_ids) ] @@ -389,7 +405,7 @@ def main(self) -> None: ] nerfstudio_frames["frames"] += pinhole_frames - if points_file: + if points_file is not None: points_path = points_file else: points_path = mps_data_dir / "global_points.csv.gz" @@ -403,9 +419,10 @@ def main(self) -> None: points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore points_data = filter_points_from_confidence(points_data) points += [cast(Any, it).position_world for it in points_data] - - if points: + print(len(nerfstudio_frames['frames'])) + if len(points) > 0: print("Saving found points to PLY...") + print(f"Total number of points found: {len(points)} in {num_recordings} recording(s) provided") pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(np.array(points)) ply_file_path = self.output_dir / "global_points.ply" @@ -415,7 +432,6 @@ def main(self) -> None: print("No global points found!") # Write the json out to disk as transforms.json - print(len(nerfstudio_frames['frames'])) print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" with open(transform_file, "w", encoding="UTF-8"): From 25f5f27f3826e22d9b19169dd62fcdd999c2089a Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 14 Jun 2024 05:28:28 -0700 Subject: [PATCH 10/78] added logic to handle eviction of a worker's cached_collated_batch --- .../data/datamanagers/base_datamanager.py | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d2394fb710..d3204ea6c1 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -313,7 +313,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" - train_num_images_to_sample_from: int = -1 # was -1 + train_num_images_to_sample_from: int = -1 """Number of images to sample during training iteration.""" train_num_times_to_repeat_images: int = -1 """When not training on all images, number of iterations before picking new @@ -335,7 +335,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - dataloader_prefetch_size : int = 2 + dataloader_prefetch_size : int = 8 """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" @@ -376,12 +376,13 @@ def __init__( self, input_dataset: Dataset, datamanager_config : DataManagerConfig = None, - num_images_to_sample_from: int = -1, + num_images_to_sample_from: int = -1, # passed in from VanillaDataManager device: Union[torch.device, str] = "cpu", collate_fn: Callable[[Any], Any] = nerfstudio_collate, exclude_batch_keys_from_device: Optional[List[str]] = None, - num_image_load_threads : int = 2, - cache_all_n_shard_per_worker : bool = True, # When False, always getting Killed/bugs for some reason + num_image_load_threads : int = 4, + cache_all_n_shard_per_worker : bool = True, # When False, always getting Killed/bugs for some reason... why? + # when cache_all_n_shard_per_worker True, getting killed because caching everything is not good ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -401,7 +402,7 @@ def __init__( self.pixel_sampler = None self.ray_generator = None self._cached_collated_batch = None - """_cached_collated_batch contains a collated batch of images that's ready for pixel sampling. I""" + """_cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker """If True, _cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" @@ -473,7 +474,9 @@ def _get_collated_batch(self, indices=None): collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ - batch_list = self._get_batch_list(indices=indices) + batch_list=self._get_batch_list(indices=indices) + # if len(batch_list) == 0: + # print(indices) # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate collated_batch = self.collate_fn(batch_list) @@ -483,29 +486,29 @@ def _get_collated_batch(self, indices=None): return collated_batch def __iter__(self): + """Defines the iterator for the dataset.""" # Set up stuff now that we're in the worker process - this_indices = list(range(len(self.input_dataset))) + dataset_indices = list(range(len(self.input_dataset))) # this_indices has len 300, at first it is the whole training dataset, but it gets partitioned into equal chunks worker_info = torch.utils.data.get_worker_info() - if self.cache_all_n_shard_per_worker: - # this_indices has len 300, at first it is the whole training dataset, but it gets partitioned into equal chunks + if self.cache_all_n_shard_per_worker: # if we want every worker to cache their partition if worker_info is None: print('TODO log. only single worker not sharding!') worker_id = -1 else: # assign this worker a deterministic uniformly sampled slice # of the dataset - per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) + per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) r = random.Random(1337) - r.shuffle(this_indices) + r.shuffle(dataset_indices) worker_id = worker_info.id slice_start = worker_id * per_worker - this_indices = this_indices[slice_start:slice_start+per_worker] - print(f'Worker ID {worker_id} working on {len(this_indices)} indices') + worker_indices = dataset_indices[slice_start:slice_start+per_worker] + print(f'Worker ID {worker_id} working on {len(worker_indices)} indices') import time start = time.time() print(f"Worker ID {worker_id} caching collated batch ...") - self._cached_collated_batch = self._get_collated_batch(indices=this_indices) + self._cached_collated_batch = self._get_collated_batch(indices=worker_indices) print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") if self.pixel_sampler is None: @@ -515,18 +518,23 @@ def __iter__(self): ) if self.ray_generator is None: self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - - while True: - if self._cached_collated_batch is None: - per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) - r = random.Random(1337) - r.shuffle(this_indices) - worker_id = worker_info.id - slice_start = worker_id * per_worker - this_indices = this_indices[slice_start:slice_start+per_worker] - collated_batch = self._get_collated_batch(this_indices) + + # if cache_all_n_shard_per_worker=True, every worker should have a _cached_collated_batch when the iterator was created (above lines) + # falling into this if statement means the worker's cached_collated_batch was evicted or cache_all_n_shard_per_worker=False + if self._cached_collated_batch is None: + if worker_info is None: + per_worker = len(dataset_indices) else: - collated_batch = self._cached_collated_batch + per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + r = random.Random(1337) + r.shuffle(dataset_indices) + worker_id = 0 if worker_info is None else worker_info.id + slice_start = worker_id * per_worker + worker_indices = dataset_indices[slice_start:slice_start+per_worker] # the indices of the datapoints in the dataset this worker will load + collated_batch = self._get_collated_batch(worker_indices) + else: + collated_batch = self._cached_collated_batch + while True: batch = self.pixel_sampler.sample(collated_batch) ray_indices = batch["indices"] ray_bundle = self.ray_generator(ray_indices) @@ -700,12 +708,10 @@ def setup_train(self): num_images_to_sample_from=self.config.train_num_images_to_sample_from, num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - num_workers= - self.world_size * 4 + num_workers=self.world_size * 4 if self.config.dataloader_num_workers == -1 else self.config.dataloader_num_workers, - prefetch_factor= - 2 + prefetch_factor=2 if self.config.dataloader_prefetch_size == -1 else self.config.dataloader_prefetch_size, pin_memory=True, From 3a8b63b9114c664f0299785d0b478f52287d910a Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 15 Jun 2024 05:35:43 -0700 Subject: [PATCH 11/78] antonio's implementation of stream batches --- .../data/datamanagers/base_datamanager.py | 234 +++++++++++------- .../scripts/datasets/process_project_aria.py | 2 +- 2 files changed, 147 insertions(+), 89 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d3204ea6c1..7af465cca8 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -56,7 +56,11 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader#, RayBatchStream +from nerfstudio.data.utils.dataloaders import ( # , RayBatchStream + CacheDataloader, + FixedIndicesEvalDataloader, + RandIndicesEvalDataloader, +) from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator @@ -335,11 +339,11 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - dataloader_prefetch_size : int = 8 + dataloader_prefetch_size: int = 8 """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" - dataloader_num_workers : int = 8 + dataloader_num_workers: int = 8 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True @@ -364,25 +368,29 @@ def __post_init__(self): TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) -import multiprocessing +import concurrent.futures import math -from torch.utils.data import Dataset -from typing import Sized +import multiprocessing import random -import concurrent.futures +from typing import Sized + +from torch.utils.data import Dataset + from nerfstudio.utils.misc import get_dict_to_torch + + class RayBatchStream(torch.utils.data.IterableDataset): def __init__( - self, - input_dataset: Dataset, - datamanager_config : DataManagerConfig = None, - num_images_to_sample_from: int = -1, # passed in from VanillaDataManager - device: Union[torch.device, str] = "cpu", - collate_fn: Callable[[Any], Any] = nerfstudio_collate, - exclude_batch_keys_from_device: Optional[List[str]] = None, - num_image_load_threads : int = 4, - cache_all_n_shard_per_worker : bool = True, # When False, always getting Killed/bugs for some reason... why? - # when cache_all_n_shard_per_worker True, getting killed because caching everything is not good + self, + input_dataset: Dataset, + datamanager_config: DataManagerConfig, + num_images_to_sample_from: int = -1, # passed in from VanillaDataManager + device: Union[torch.device, str] = "cpu", + collate_fn: Callable[[Any], Any] = nerfstudio_collate, + exclude_batch_keys_from_device: Optional[List[str]] = None, + num_image_load_threads: int = 4, + cache_all_n_shard_per_worker: bool = True, # When False, always getting Killed/bugs for some reason... why? + # when cache_all_n_shard_per_worker True, getting killed because caching everything is not good ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -395,21 +403,21 @@ def __init__( self.device = device self.collate_fn = collate_fn # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults - self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults + self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults self.exclude_batch_keys_from_device = exclude_batch_keys_from_device self.datamanager_config = datamanager_config - self.pixel_sampler = None - self.ray_generator = None + self.pixel_sampler: PixelSampler = None + self.ray_generator: RayGenerator = None self._cached_collated_batch = None """_cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker """If True, _cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" - - def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: + + def _get_pixel_sampler(self, dataset: "TDataset", num_rays_per_batch: int) -> PixelSampler: """copy-pasta from VanillaDataManager.""" - from nerfstudio.cameras.cameras import Cameras, CameraType - from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig + from nerfstudio.cameras.cameras import CameraType + from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSamplerConfig if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: return PatchPixelSamplerConfig().setup( @@ -428,22 +436,25 @@ def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> Pi num_rays_per_batch=num_rays_per_batch, fisheye_crop_radius=fisheye_crop_radius, ) - + def _get_batch_list(self, indices=None): - """Returns a list representing a single batch from the dataset attribute. + """Returns a list representing a single batch from the dataset attribute. Each item of the list is a dictionary with dict_keys(['image_idx', 'image']) representing 1 image. This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch The length of the list is equal to the (# of training images) / (num_workers)""" assert isinstance(self.input_dataset, Sized) if indices is None: + # Note: self.num_images_to_sample_from is usually -1, but _get_batch_list is usually called with indices != None. + # _get_batch_list is used by _get_collated_batch, whose indices = some partition of the dataset indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) batch_list = [] results = [] # num_threads = int(self.num_ds_load_threads) * 4 num_threads = ( - int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker + int(self.num_image_load_threads) + if not self.cache_all_n_shard_per_worker else 4 * int(self.num_image_load_threads) ) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) @@ -453,6 +464,7 @@ def _get_batch_list(self, indices=None): # NB: this is I/O heavy because we are going to disk and reading an image filename # hence multi-threaded inside the worker from tqdm.auto import tqdm + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) @@ -465,16 +477,16 @@ def _get_batch_list(self, indices=None): for res in results: batch_list.append(res.result()) return batch_list - + def _get_collated_batch(self, indices=None): """Takes the output of _get_batch_list and collates them with nerfstudio_collate() Note: dict is an instance of collections.abc.Mapping - + The resulting output is collated_batch: a dictionary with dict_keys(['image_idx', 'image']) collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) - collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) + collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ - batch_list=self._get_batch_list(indices=indices) + batch_list = self._get_batch_list(indices=indices) # if len(batch_list) == 0: # print(indices) # print(type(batch_list[0])) # prints @@ -486,59 +498,104 @@ def _get_collated_batch(self, indices=None): return collated_batch def __iter__(self): - """Defines the iterator for the dataset.""" - # Set up stuff now that we're in the worker process - dataset_indices = list(range(len(self.input_dataset))) # this_indices has len 300, at first it is the whole training dataset, but it gets partitioned into equal chunks + """This implementation has every worker cache the indices of the images they will use to generate rays.""" + dataset_indices = list( + range(len(self.input_dataset)) + ) # this_indices has length = numTrainingImages, at first it is the whole training dataset, but it gets partitioned into equal chunks worker_info = torch.utils.data.get_worker_info() - if self.cache_all_n_shard_per_worker: # if we want every worker to cache their partition - if worker_info is None: - print('TODO log. only single worker not sharding!') - worker_id = -1 - else: - # assign this worker a deterministic uniformly sampled slice - # of the dataset - per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) - r = random.Random(1337) - r.shuffle(dataset_indices) - worker_id = worker_info.id - slice_start = worker_id * per_worker - worker_indices = dataset_indices[slice_start:slice_start+per_worker] - print(f'Worker ID {worker_id} working on {len(worker_indices)} indices') - - import time - start = time.time() - print(f"Worker ID {worker_id} caching collated batch ...") - self._cached_collated_batch = self._get_collated_batch(indices=worker_indices) - print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") - - if self.pixel_sampler is None: - self.pixel_sampler = self._get_pixel_sampler( - self.input_dataset, - self.datamanager_config.train_num_rays_per_batch - ) - if self.ray_generator is None: - self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - - # if cache_all_n_shard_per_worker=True, every worker should have a _cached_collated_batch when the iterator was created (above lines) - # falling into this if statement means the worker's cached_collated_batch was evicted or cache_all_n_shard_per_worker=False - if self._cached_collated_batch is None: - if worker_info is None: - per_worker = len(dataset_indices) - else: - per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) - r = random.Random(1337) - r.shuffle(dataset_indices) - worker_id = 0 if worker_info is None else worker_info.id - slice_start = worker_id * per_worker - worker_indices = dataset_indices[slice_start:slice_start+per_worker] # the indices of the datapoints in the dataset this worker will load - collated_batch = self._get_collated_batch(worker_indices) - else: - collated_batch = self._cached_collated_batch + if worker_info is not None: # if we have multiple processes + per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + slice_start = worker_info.id * per_worker + else: # we only have a single process + per_worker = len(self.input_dataset) + slice_start = 0 + worker_indices = dataset_indices[ + slice_start : slice_start + per_worker + ] # the indices of the datapoints in the dataset this worker will load + r = random.Random(3301) + loop_iterations = 32 + num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch // loop_iterations # default train_num_rays_per_batch is 4096 + worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) while True: - batch = self.pixel_sampler.sample(collated_batch) - ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) - yield ray_bundle, batch + ray_bundle_list = [] # list of RayBundle objects + batch_list = [] # list of pytorch tensors with shape torch.Size([, 3]) + for _ in range(loop_iterations): + image_indices = r.shuffle(worker_indices)[:self.datamanager_config.train_num_images_to_sample_from] # obtain num_images_per_loop + collated_batch = self._get_collated_batch(image_indices) + batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + ray_indices = batch["indices"] + ray_bundle = self.ray_generator(ray_indices) + ray_bundle_list.append(ray_bundle) + batch_list.append(ray_bundle) + + combined_metadata = {} + concatenated_ray_bundle = RayBundle( + origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0), + directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0), + pixel_area=torch.cat([ray_bundle_i.pixel_area for ray_bundle_i in ray_bundle_list], dim=0), + camera_indices=torch.cat([ray_bundle_i.camera_indices for ray_bundle_i in ray_bundle_list], dim=0), + metadata=combined_metadata, + ) + concatenated_batch = { + "image" : torch.cat([batch_i["image"] for batch_i in batch_list], dim=0), + "indices": torch.cat([batch_i["indices"] for batch_i in batch_list], dim=0), + } + yield concatenated_ray_bundle, concatenated_batch + + # def __iter__(self): + # """Defines the iterator for the dataset.""" + # # Set up stuff now that we're in the worker process + # dataset_indices = list(range(len(self.input_dataset))) # this_indices has length = numTrainingImages, at first it is the whole training dataset, but it gets partitioned into equal chunks + # worker_info = torch.utils.data.get_worker_info() + # if self.cache_all_n_shard_per_worker: # if we want every worker to cache their partition + # if worker_info is None: + # print('TODO log. only single worker not sharding!') + # worker_id = -1 + # else: + # # assign this worker a deterministic uniformly sampled slice + # # of the dataset + # per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + # r = random.Random(1337) + # r.shuffle(dataset_indices) + # worker_id = worker_info.id + # slice_start = worker_id * per_worker + # worker_indices = dataset_indices[slice_start:slice_start+per_worker] + # print(f'Worker ID {worker_id} working on {len(worker_indices)} indices') + + # import time + # start = time.time() + # print(f"Worker ID {worker_id} caching collated batch ...") + # self._cached_collated_batch = self._get_collated_batch(indices=worker_indices) + # print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") + + # if self.pixel_sampler is None: + # self.pixel_sampler = self._get_pixel_sampler( + # self.input_dataset, + # self.datamanager_config.train_num_rays_per_batch + # ) + # if self.ray_generator is None: + # self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + + # # if cache_all_n_shard_per_worker=True, every worker should have a _cached_collated_batch when the iterator was created (above lines) + # # falling into this if statement means the worker's cached_collated_batch was evicted or cache_all_n_shard_per_worker=False + # if self._cached_collated_batch is None: + # if worker_info is None: + # per_worker = len(dataset_indices) + # else: + # per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + # r = random.Random(1337) + # r.shuffle(dataset_indices) + # worker_id = 0 if worker_info is None else worker_info.id + # slice_start = worker_id * per_worker + # worker_indices = dataset_indices[slice_start:slice_start+per_worker] # the indices of the datapoints in the dataset this worker will load + # collated_batch = self._get_collated_batch(worker_indices) + # else: + # collated_batch = self._cached_collated_batch + # while True: + # batch = self.pixel_sampler.sample(collated_batch) + # ray_indices = batch["indices"] + # ray_bundle = self.ray_generator(ray_indices) + # yield ray_bundle, batch def identity(x): @@ -675,7 +732,8 @@ def setup_train(self): if self.config.use_ray_train_dataloader: import torch.multiprocessing as mp - mp.set_start_method('spawn') + + mp.set_start_method("spawn") self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, datamanager_config=self.config, @@ -687,7 +745,7 @@ def setup_train(self): collate_fn=self.config.collate_fn, # num_workers=self.world_size * 4,# this is part of torch.utils.data.DataLoader # pin_memory=True, # this is part of torch.utils.data.DataLoader - ) + ) self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, batch_size=1, @@ -709,11 +767,11 @@ def setup_train(self): num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, num_workers=self.world_size * 4 - if self.config.dataloader_num_workers == -1 - else self.config.dataloader_num_workers, + if self.config.dataloader_num_workers == -1 + else self.config.dataloader_num_workers, prefetch_factor=2 - if self.config.dataloader_prefetch_size == -1 - else self.config.dataloader_prefetch_size, + if self.config.dataloader_prefetch_size == -1 + else self.config.dataloader_prefetch_size, pin_memory=True, collate_fn=self.config.collate_fn, exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index 54998a6564..17785c29b7 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -299,7 +299,7 @@ class ProcessProjectAria: include_side_cameras: bool = False """If True, include and process the images captured by the grayscale side cameras. If False, only uses the main RGB camera's data.""" - max_dataset_size: int = 600 + max_dataset_size: int = -1 """Max number of images to train on. If the provided vrs_file has more images than max_dataset_size, images will be sampled approximately evenly. If max_dataset_size=-1, use all images available.""" From 536c6ca51c32329bd2959ebe659a93097455aef5 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 15 Jun 2024 06:13:39 -0700 Subject: [PATCH 12/78] training on a dataset with 4000 images works! --- nerfstudio/data/datamanagers/base_datamanager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 7af465cca8..45e9402b87 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -343,7 +343,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" - dataloader_num_workers: int = 8 + dataloader_num_workers: int = 1 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True @@ -516,17 +516,20 @@ def __iter__(self): loop_iterations = 32 num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch // loop_iterations # default train_num_rays_per_batch is 4096 worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) + if self.ray_generator is None: + self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) while True: ray_bundle_list = [] # list of RayBundle objects batch_list = [] # list of pytorch tensors with shape torch.Size([, 3]) for _ in range(loop_iterations): - image_indices = r.shuffle(worker_indices)[:self.datamanager_config.train_num_images_to_sample_from] # obtain num_images_per_loop + r.shuffle(worker_indices) + image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices collated_batch = self._get_collated_batch(image_indices) batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. ray_indices = batch["indices"] ray_bundle = self.ray_generator(ray_indices) ray_bundle_list.append(ray_bundle) - batch_list.append(ray_bundle) + batch_list.append(batch) combined_metadata = {} concatenated_ray_bundle = RayBundle( @@ -739,7 +742,7 @@ def setup_train(self): datamanager_config=self.config, # self.train_pixel_sampler, # self.train_ray_generator, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_images_to_sample_from=100,#self.config.train_num_images_to_sample_from, # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # no work device=self.device, collate_fn=self.config.collate_fn, From 43a00612fd92f4c2e30291fd1855d9808314ba91 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 15 Jun 2024 07:02:45 -0700 Subject: [PATCH 13/78] some configuration speedups, loops aren't actually needed! --- .../data/datamanagers/base_datamanager.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 45e9402b87..f2d664ae3c 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -339,11 +339,11 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - dataloader_prefetch_size: int = 8 + dataloader_prefetch_size: int = 1 """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" - dataloader_num_workers: int = 1 + dataloader_num_workers: int = 2 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" use_ray_train_dataloader: bool = True @@ -513,7 +513,7 @@ def __iter__(self): slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load r = random.Random(3301) - loop_iterations = 32 + loop_iterations = 1 num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch // loop_iterations # default train_num_rays_per_batch is 4096 worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) if self.ray_generator is None: @@ -524,14 +524,32 @@ def __iter__(self): for _ in range(loop_iterations): r.shuffle(worker_indices) image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices + + # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. collated_batch = self._get_collated_batch(image_indices) + + """ + Here, the variable 'batch' refers to the output of our pixel sampler. In particular + - batch is a dict_keys(['image', 'indices']) - output of pixel_sampler + - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’] + - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol) + + What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, + and returns them as the variable `indices` which has shape `torch.Size([4096, 3])` , where each row represents a pixel (image_idx, y_pos, x_pos) + """ batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + ray_indices = batch["indices"] ray_bundle = self.ray_generator(ray_indices) ray_bundle_list.append(ray_bundle) batch_list.append(batch) combined_metadata = {} + if "fisheye_crop_radius" in ray_bundle_list[0].metadata: + combined_metadata["fisheye_crop_radius"] = ray_bundle_list[0].metadata["fisheye_crop_radius"] + if "directions_norm" in ray_bundle_list[0].metadata: + combined_metadata["directions_norm"] = torch.cat([ray_bundle_i.metadata["directions_norm"] for ray_bundle_i in ray_bundle_list], dim=0) + concatenated_ray_bundle = RayBundle( origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0), directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0), From fa7cf306d96acb7da6235d61b324e81dfad6d132 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 15 Jun 2024 07:59:53 -0700 Subject: [PATCH 14/78] quick fix adjustment to aria --- nerfstudio/scripts/datasets/process_project_aria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index 17785c29b7..ba5d8c9f6c 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -377,7 +377,7 @@ def main(self) -> None: camera_name=names[i], pinhole=True, ) - for index in sampling_indicies + for index in range(provider.get_num_data(stream_id)) ] for i, stream_id in enumerate(stream_ids) ] From 927cb6a9faead03f96b27ce5effad25e05548556 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 16 Jun 2024 04:23:55 -0700 Subject: [PATCH 15/78] removed unnecessary looping --- .../data/datamanagers/base_datamanager.py | 121 +++--------------- 1 file changed, 19 insertions(+), 102 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index f2d664ae3c..29e3542a9a 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -339,7 +339,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - dataloader_prefetch_size: int = 1 + dataloader_prefetch_size: int = 2 """The limit number of batches a worker will start loading once an iterator is created. Each next() call on the iterator has the CPU prepare more batches up to this limit while the GPU is performing forward and backward passes on the model.""" @@ -513,110 +513,28 @@ def __iter__(self): slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load r = random.Random(3301) - loop_iterations = 1 - num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch // loop_iterations # default train_num_rays_per_batch is 4096 + num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch # default train_num_rays_per_batch is 4096 worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) if self.ray_generator is None: self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) while True: - ray_bundle_list = [] # list of RayBundle objects - batch_list = [] # list of pytorch tensors with shape torch.Size([, 3]) - for _ in range(loop_iterations): - r.shuffle(worker_indices) - image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices - - # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. - collated_batch = self._get_collated_batch(image_indices) - - """ - Here, the variable 'batch' refers to the output of our pixel sampler. In particular - - batch is a dict_keys(['image', 'indices']) - output of pixel_sampler - - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’] - - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol) - - What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, - and returns them as the variable `indices` which has shape `torch.Size([4096, 3])` , where each row represents a pixel (image_idx, y_pos, x_pos) - """ - batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. - - ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) - ray_bundle_list.append(ray_bundle) - batch_list.append(batch) + r.shuffle(worker_indices) + image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices - combined_metadata = {} - if "fisheye_crop_radius" in ray_bundle_list[0].metadata: - combined_metadata["fisheye_crop_radius"] = ray_bundle_list[0].metadata["fisheye_crop_radius"] - if "directions_norm" in ray_bundle_list[0].metadata: - combined_metadata["directions_norm"] = torch.cat([ray_bundle_i.metadata["directions_norm"] for ray_bundle_i in ray_bundle_list], dim=0) - - concatenated_ray_bundle = RayBundle( - origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0), - directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0), - pixel_area=torch.cat([ray_bundle_i.pixel_area for ray_bundle_i in ray_bundle_list], dim=0), - camera_indices=torch.cat([ray_bundle_i.camera_indices for ray_bundle_i in ray_bundle_list], dim=0), - metadata=combined_metadata, - ) - concatenated_batch = { - "image" : torch.cat([batch_i["image"] for batch_i in batch_list], dim=0), - "indices": torch.cat([batch_i["indices"] for batch_i in batch_list], dim=0), - } - yield concatenated_ray_bundle, concatenated_batch - - # def __iter__(self): - # """Defines the iterator for the dataset.""" - # # Set up stuff now that we're in the worker process - # dataset_indices = list(range(len(self.input_dataset))) # this_indices has length = numTrainingImages, at first it is the whole training dataset, but it gets partitioned into equal chunks - # worker_info = torch.utils.data.get_worker_info() - # if self.cache_all_n_shard_per_worker: # if we want every worker to cache their partition - # if worker_info is None: - # print('TODO log. only single worker not sharding!') - # worker_id = -1 - # else: - # # assign this worker a deterministic uniformly sampled slice - # # of the dataset - # per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) - # r = random.Random(1337) - # r.shuffle(dataset_indices) - # worker_id = worker_info.id - # slice_start = worker_id * per_worker - # worker_indices = dataset_indices[slice_start:slice_start+per_worker] - # print(f'Worker ID {worker_id} working on {len(worker_indices)} indices') - - # import time - # start = time.time() - # print(f"Worker ID {worker_id} caching collated batch ...") - # self._cached_collated_batch = self._get_collated_batch(indices=worker_indices) - # print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") - - # if self.pixel_sampler is None: - # self.pixel_sampler = self._get_pixel_sampler( - # self.input_dataset, - # self.datamanager_config.train_num_rays_per_batch - # ) - # if self.ray_generator is None: - # self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - - # # if cache_all_n_shard_per_worker=True, every worker should have a _cached_collated_batch when the iterator was created (above lines) - # # falling into this if statement means the worker's cached_collated_batch was evicted or cache_all_n_shard_per_worker=False - # if self._cached_collated_batch is None: - # if worker_info is None: - # per_worker = len(dataset_indices) - # else: - # per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) - # r = random.Random(1337) - # r.shuffle(dataset_indices) - # worker_id = 0 if worker_info is None else worker_info.id - # slice_start = worker_id * per_worker - # worker_indices = dataset_indices[slice_start:slice_start+per_worker] # the indices of the datapoints in the dataset this worker will load - # collated_batch = self._get_collated_batch(worker_indices) - # else: - # collated_batch = self._cached_collated_batch - # while True: - # batch = self.pixel_sampler.sample(collated_batch) - # ray_indices = batch["indices"] - # ray_bundle = self.ray_generator(ray_indices) - # yield ray_bundle, batch + # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. + collated_batch = self._get_collated_batch(image_indices) + """ + Here, the variable 'batch' refers to the output of our pixel sampler. + - batch is a dict_keys(['image', 'indices']) + - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’] + - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol) + What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, + and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) + """ + batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + ray_indices = batch["indices"] + ray_bundle = self.ray_generator(ray_indices) + yield ray_bundle, batch def identity(x): @@ -765,7 +683,6 @@ def setup_train(self): device=self.device, collate_fn=self.config.collate_fn, # num_workers=self.world_size * 4,# this is part of torch.utils.data.DataLoader - # pin_memory=True, # this is part of torch.utils.data.DataLoader ) self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, @@ -776,7 +693,7 @@ def setup_train(self): pin_memory=False, # Our dataset does batching / collation collate_fn=identity, - # pin_memory_device=self.device + # pin_memory_device=self.device, ) self.iter_train_image_dataloader = None self.iter_train_raybundles = iter(self.ray_dataloader) From 814f2c21b5e3d8f00f6ed9b4be889eaf4aadddfc Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 25 Jun 2024 07:22:46 -0700 Subject: [PATCH 16/78] much faster training when adding i variable to collate every 5 ray bundles --- .../data/datamanagers/base_datamanager.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 29e3542a9a..ddcedaf4fb 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -339,10 +339,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - dataloader_prefetch_size: int = 2 + prefetch_factor: int = 2 """The limit number of batches a worker will start loading once an iterator is created. - Each next() call on the iterator has the CPU prepare more batches up to this - limit while the GPU is performing forward and backward passes on the model.""" + More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" dataloader_num_workers: int = 2 """The number of workers performing the dataloading from either disk/RAM, which includes undistortion, pixel sampling, ray generation, collating, etc.""" @@ -402,7 +401,6 @@ def __init__( self.num_images_to_sample_from = num_images_to_sample_from self.device = device self.collate_fn = collate_fn - # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults self.exclude_batch_keys_from_device = exclude_batch_keys_from_device @@ -517,12 +515,15 @@ def __iter__(self): worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) if self.ray_generator is None: self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + i = 0 while True: - r.shuffle(worker_indices) - image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices - - # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. - collated_batch = self._get_collated_batch(image_indices) + if i % 5 == 0: + r.shuffle(worker_indices) + image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices + + # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. + collated_batch = self._get_collated_batch(image_indices) + i += 1 """ Here, the variable 'batch' refers to the output of our pixel sampler. - batch is a dict_keys(['image', 'indices']) @@ -671,8 +672,8 @@ def setup_train(self): if self.config.use_ray_train_dataloader: import torch.multiprocessing as mp - mp.set_start_method("spawn") + self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, datamanager_config=self.config, @@ -682,13 +683,12 @@ def setup_train(self): # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # no work device=self.device, collate_fn=self.config.collate_fn, - # num_workers=self.world_size * 4,# this is part of torch.utils.data.DataLoader ) self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, batch_size=1, num_workers=self.config.dataloader_num_workers, - prefetch_factor=self.config.dataloader_prefetch_size, + prefetch_factor=self.config.prefetch_factor, shuffle=False, pin_memory=False, # Our dataset does batching / collation From 247ac3ebf51c07c3d5b8e6edfe14763efec7f93e Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 7 Jul 2024 07:10:55 -0700 Subject: [PATCH 17/78] cleanup unnecssary variables in Dataloader --- nerfstudio/data/datamanagers/base_datamanager.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index ddcedaf4fb..05c5e633c4 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -339,7 +339,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = 2 + prefetch_factor: int = 1 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" dataloader_num_workers: int = 2 @@ -677,10 +677,7 @@ def setup_train(self): self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, datamanager_config=self.config, - # self.train_pixel_sampler, - # self.train_ray_generator, - num_images_to_sample_from=100,#self.config.train_num_images_to_sample_from, - # num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # no work + num_images_to_sample_from=100, # self.config.train_num_images_to_sample_from, device=self.device, collate_fn=self.config.collate_fn, ) @@ -693,7 +690,7 @@ def setup_train(self): pin_memory=False, # Our dataset does batching / collation collate_fn=identity, - # pin_memory_device=self.device, + # pin_memory_device=self.device, # did not actually speed up my implementation ) self.iter_train_image_dataloader = None self.iter_train_raybundles = iter(self.ray_dataloader) From 55d0803a540d36b4320e99f3a113b66234feb435 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 11 Jul 2024 15:22:13 -0700 Subject: [PATCH 18/78] further cleanup --- .../data/datamanagers/base_datamanager.py | 2 - nerfstudio/models/bilateral_splat.py | 1110 +++++++++++++++++ 2 files changed, 1110 insertions(+), 2 deletions(-) create mode 100644 nerfstudio/models/bilateral_splat.py diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 05c5e633c4..2265f86d6a 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -485,8 +485,6 @@ def _get_collated_batch(self, indices=None): collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ batch_list = self._get_batch_list(indices=indices) - # if len(batch_list) == 0: - # print(indices) # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate collated_batch = self.collate_fn(batch_list) diff --git a/nerfstudio/models/bilateral_splat.py b/nerfstudio/models/bilateral_splat.py new file mode 100644 index 0000000000..6a90c337c5 --- /dev/null +++ b/nerfstudio/models/bilateral_splat.py @@ -0,0 +1,1110 @@ +# ruff: noqa: E741 +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Gaussian Splatting implementation that combines many recent advancements. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Optional, Tuple, Type, Union + +import numpy as np +import torch +from gsplat.cuda_legacy._torch_impl import quat_to_rotmat + +try: + from gsplat.rendering import rasterization +except ImportError: + print("Please install gsplat>=1.0.0") +from gsplat.cuda_legacy._wrapper import num_sh_bases +from pytorch_msssim import SSIM +from torch.nn import Parameter + +from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig +from nerfstudio.cameras.cameras import Cameras +from nerfstudio.data.scene_box import OrientedBox +from nerfstudio.engine.callbacks import ( + TrainingCallback, + TrainingCallbackAttributes, + TrainingCallbackLocation, +) +from nerfstudio.engine.optimizers import Optimizers +from nerfstudio.models.base_model import Model, ModelConfig +from nerfstudio.utils.colors import get_color +from nerfstudio.utils.misc import torch_compile +from nerfstudio.utils.rich_utils import CONSOLE +from bilags.lib_bilagrid import slice, BilateralGrid, total_variation_loss + + +def random_quat_tensor(N): + """ + Defines a random quaternion tensor of shape (N, 4) + """ + u = torch.rand(N) + v = torch.rand(N) + w = torch.rand(N) + return torch.stack( + [ + torch.sqrt(1 - u) * torch.sin(2 * math.pi * v), + torch.sqrt(1 - u) * torch.cos(2 * math.pi * v), + torch.sqrt(u) * torch.sin(2 * math.pi * w), + torch.sqrt(u) * torch.cos(2 * math.pi * w), + ], + dim=-1, + ) + + +def RGB2SH(rgb): + """ + Converts from RGB values [0,1] to the 0th spherical harmonic coefficient + """ + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +def SH2RGB(sh): + """ + Converts from the 0th spherical harmonic coefficient to RGB values [0,1] + """ + C0 = 0.28209479177387814 + return sh * C0 + 0.5 + + +def resize_image(image: torch.Tensor, d: int): + """ + Downscale images using the same 'area' method in opencv + + :param image shape [H, W, C] + :param d downscale factor (must be 2, 4, 8, etc.) + + return downscaled image in shape [H//d, W//d, C] + """ + import torch.nn.functional as tf + + image = image.to(torch.float32) + weight = (1.0 / (d * d)) * torch.ones( + (1, 1, d, d), dtype=torch.float32, device=image.device + ) + return ( + tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d) + .squeeze(1) + .permute(1, 2, 0) + ) + + +@torch_compile() +def get_viewmat(optimized_camera_to_world): + """ + function that converts c2w to gsplat world2camera matrix, using compile for some speed + """ + R = optimized_camera_to_world[:, :3, :3] # 3 x 3 + T = optimized_camera_to_world[:, :3, 3:4] # 3 x 1 + # flip the z and y axes to align with gsplat conventions + R = R * torch.tensor([[[1, -1, -1]]], device=R.device, dtype=R.dtype) + # analytic matrix inverse to get world2camera matrix + R_inv = R.transpose(1, 2) + T_inv = -torch.bmm(R_inv, T) + viewmat = torch.zeros(R.shape[0], 4, 4, device=R.device, dtype=R.dtype) + viewmat[:, 3, 3] = 1.0 # homogenous + viewmat[:, :3, :3] = R_inv + viewmat[:, :3, 3:4] = T_inv + return viewmat + + +@dataclass +class BilagsModelConfig(ModelConfig): + """Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting""" + + _target: Type = field(default_factory=lambda: BilagsModel) + warmup_length: int = 500 + """period of steps where refinement is turned off""" + refine_every: int = 100 + """period of steps where gaussians are culled and densified""" + resolution_schedule: int = 3000 + """training starts at 1/d resolution, every n steps this is doubled""" + background_color: Literal["random", "black", "white"] = "random" + """Whether to randomize the background color.""" + num_downscales: int = 2 + """at the beginning, resolution is 1/2^d, where d is this number""" + cull_alpha_thresh: float = 0.1 + """threshold of opacity for culling gaussians. One can set it to a lower value (e.g. 0.005) for higher quality.""" + cull_scale_thresh: float = 0.5 + """threshold of scale for culling huge gaussians""" + continue_cull_post_densification: bool = True + """If True, continue to cull gaussians post refinement""" + reset_alpha_every: int = 30 + """Every this many refinement steps, reset the alpha""" + densify_grad_thresh: float = 0.0008 + """threshold of positional gradient norm for densifying gaussians""" + densify_size_thresh: float = 0.01 + """below this size, gaussians are *duplicated*, otherwise split""" + n_split_samples: int = 2 + """number of samples to split gaussians into""" + sh_degree_interval: int = 1000 + """every n intervals turn on another sh degree""" + cull_screen_size: float = 0.15 + """if a gaussian is more than this percent of screen space, cull it""" + split_screen_size: float = 0.05 + """if a gaussian is more than this percent of screen space, split it""" + stop_screen_size_at: int = 4000 + """stop culling/splitting at this step WRT screen size of gaussians""" + random_init: bool = False + """whether to initialize the positions uniformly randomly (not SFM points)""" + num_random: int = 50000 + """Number of gaussians to initialize if random init is used""" + random_scale: float = 10.0 + "Size of the cube to initialize random gaussians within" + ssim_lambda: float = 0.2 + """weight of ssim loss""" + stop_split_at: int = 15000 + """stop splitting at this step""" + sh_degree: int = 3 + """maximum degree of spherical harmonics to use""" + use_scale_regularization: bool = False + """If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians.""" + max_gauss_ratio: float = 10.0 + """threshold of ratio of gaussian max to min scale before applying regularization + loss from the PhysGaussian paper + """ + output_depth_during_training: bool = False + """If True, output depth during training. Otherwise, only output depth during evaluation.""" + rasterize_mode: Literal["classic", "antialiased"] = "classic" + """ + Classic mode of rendering will use the EWA volume splatting with a [0.3, 0.3] screen space blurring kernel. This + approach is however not suitable to render tiny gaussians at higher or lower resolution than the captured, which + results "aliasing-like" artifacts. The antialiased mode overcomes this limitation by calculating compensation factors + and apply them to the opacities of gaussians to preserve the total integrated density of splats. + + However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that + were implemented for classic mode can not render antialiased mode PLY properly without modifications. + """ + camera_optimizer: CameraOptimizerConfig = field( + default_factory=lambda: CameraOptimizerConfig(mode="off") + ) + """Config of the camera optimizer to use""" + use_bilateral_grid: bool = True + + +class BilagsModel(Model): + """Nerfstudio's implementation of Gaussian Splatting + + Args: + config: Splatfacto configuration to instantiate model + """ + + config: BilagsModelConfig + + def __init__( + self, + *args, + seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ): + self.seed_points = seed_points + super().__init__(*args, **kwargs) + + def populate_modules(self): + if self.seed_points is not None and not self.config.random_init: + means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) + else: + means = torch.nn.Parameter( + (torch.rand((self.config.num_random, 3)) - 0.5) + * self.config.random_scale + ) + self.xys_grad_norm = None + self.max_2Dsize = None + distances, _ = self.k_nearest_sklearn(means.data, 3) + distances = torch.from_numpy(distances) + # find the average of the three nearest neighbors for each point and use that as the scale + avg_dist = distances.mean(dim=-1, keepdim=True) + scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3))) + num_points = means.shape[0] + quats = torch.nn.Parameter(random_quat_tensor(num_points)) + dim_sh = num_sh_bases(self.config.sh_degree) + + if ( + self.seed_points is not None + and not self.config.random_init + # We can have colors without points. + and self.seed_points[1].shape[0] > 0 + ): + shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() + if self.config.sh_degree > 0: + shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) + shs[:, 1:, 3:] = 0.0 + else: + CONSOLE.log("use color only optimization with sigmoid activation") + shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10) + features_dc = torch.nn.Parameter(shs[:, 0, :]) + features_rest = torch.nn.Parameter(shs[:, 1:, :]) + else: + features_dc = torch.nn.Parameter(torch.rand(num_points, 3)) + features_rest = torch.nn.Parameter(torch.zeros((num_points, dim_sh - 1, 3))) + + opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(num_points, 1))) + self.gauss_params = torch.nn.ParameterDict( + { + "means": means, + "scales": scales, + "quats": quats, + "features_dc": features_dc, + "features_rest": features_rest, + "opacities": opacities, + } + ) + + self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup( + num_cameras=self.num_train_data, device="cpu" + ) + + # metrics + from torchmetrics.image import PeakSignalNoiseRatio + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + + self.psnr = PeakSignalNoiseRatio(data_range=1.0) + self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) + self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) + self.step = 0 + + self.crop_box: Optional[OrientedBox] = None + if self.config.background_color == "random": + self.background_color = torch.tensor( + [0.1490, 0.1647, 0.2157] + ) # This color is the same as the default background color in Viser. This would only affect the background color when rendering. + else: + self.background_color = get_color(self.config.background_color) + + if self.config.use_bilateral_grid: + self.bil_grids = BilateralGrid(num=self.num_train_data) + + @property + def colors(self): + if self.config.sh_degree > 0: + return SH2RGB(self.features_dc) + else: + return torch.sigmoid(self.features_dc) + + @property + def shs_0(self): + return self.features_dc + + @property + def shs_rest(self): + return self.features_rest + + @property + def num_points(self): + return self.means.shape[0] + + @property + def means(self): + return self.gauss_params["means"] + + @property + def scales(self): + return self.gauss_params["scales"] + + @property + def quats(self): + return self.gauss_params["quats"] + + @property + def features_dc(self): + return self.gauss_params["features_dc"] + + @property + def features_rest(self): + return self.gauss_params["features_rest"] + + @property + def opacities(self): + return self.gauss_params["opacities"] + + def load_state_dict(self, dict, **kwargs): # type: ignore + # resize the parameters to match the new number of points + self.step = 30000 + if "means" in dict: + # For backwards compatibility, we remap the names of parameters from + # means->gauss_params.means since old checkpoints have that format + for p in [ + "means", + "scales", + "quats", + "features_dc", + "features_rest", + "opacities", + ]: + dict[f"gauss_params.{p}"] = dict[p] + newp = dict["gauss_params.means"].shape[0] + for name, param in self.gauss_params.items(): + old_shape = param.shape + new_shape = (newp,) + old_shape[1:] + self.gauss_params[name] = torch.nn.Parameter( + torch.zeros(new_shape, device=self.device) + ) + super().load_state_dict(dict, **kwargs) + + def k_nearest_sklearn(self, x: torch.Tensor, k: int): + """ + Find k-nearest neighbors using sklearn's NearestNeighbors. + x: The data tensor of shape [num_samples, num_features] + k: The number of neighbors to retrieve + """ + # Convert tensor to numpy array + x_np = x.cpu().numpy() + + # Build the nearest neighbors model + from sklearn.neighbors import NearestNeighbors + + nn_model = NearestNeighbors( + n_neighbors=k + 1, algorithm="auto", metric="euclidean" + ).fit(x_np) + + # Find the k-nearest neighbors + distances, indices = nn_model.kneighbors(x_np) + + # Exclude the point itself from the result and return + return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) + + def remove_from_optim(self, optimizer, deleted_mask, new_params): + """removes the deleted_mask from the optimizer provided""" + assert len(new_params) == 1 + # assert isinstance(optimizer, torch.optim.Adam), "Only works with Adam" + + param = optimizer.param_groups[0]["params"][0] + param_state = optimizer.state[param] + del optimizer.state[param] + + # Modify the state directly without deleting and reassigning. + if "exp_avg" in param_state: + param_state["exp_avg"] = param_state["exp_avg"][~deleted_mask] + param_state["exp_avg_sq"] = param_state["exp_avg_sq"][~deleted_mask] + + # Update the parameter in the optimizer's param group. + del optimizer.param_groups[0]["params"][0] + del optimizer.param_groups[0]["params"] + optimizer.param_groups[0]["params"] = new_params + optimizer.state[new_params[0]] = param_state + + def remove_from_all_optim(self, optimizers, deleted_mask): + param_groups = self.get_gaussian_param_groups() + for group, param in param_groups.items(): + self.remove_from_optim(optimizers.optimizers[group], deleted_mask, param) + torch.cuda.empty_cache() + + def dup_in_optim(self, optimizer, dup_mask, new_params, n=2): + """adds the parameters to the optimizer""" + param = optimizer.param_groups[0]["params"][0] + param_state = optimizer.state[param] + if "exp_avg" in param_state: + repeat_dims = (n,) + tuple( + 1 for _ in range(param_state["exp_avg"].dim() - 1) + ) + param_state["exp_avg"] = torch.cat( + [ + param_state["exp_avg"], + torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat( + *repeat_dims + ), + ], + dim=0, + ) + param_state["exp_avg_sq"] = torch.cat( + [ + param_state["exp_avg_sq"], + torch.zeros_like( + param_state["exp_avg_sq"][dup_mask.squeeze()] + ).repeat(*repeat_dims), + ], + dim=0, + ) + del optimizer.state[param] + optimizer.state[new_params[0]] = param_state + optimizer.param_groups[0]["params"] = new_params + del param + + def dup_in_all_optim(self, optimizers, dup_mask, n): + param_groups = self.get_gaussian_param_groups() + for group, param in param_groups.items(): + self.dup_in_optim(optimizers.optimizers[group], dup_mask, param, n) + + def after_train(self, step: int): + assert step == self.step + # to save some training time, we no longer need to update those stats post refinement + if self.step >= self.config.stop_split_at: + return + with torch.no_grad(): + # keep track of a moving average of grad norms + visible_mask = (self.radii > 0).flatten() + grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore + # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}") + if self.xys_grad_norm is None: + self.xys_grad_norm = torch.zeros( + self.num_points, device=self.device, dtype=torch.float32 + ) + self.vis_counts = torch.ones( + self.num_points, device=self.device, dtype=torch.float32 + ) + assert self.vis_counts is not None + self.vis_counts[visible_mask] += 1 + self.xys_grad_norm[visible_mask] += grads + # update the max screen size, as a ratio of number of pixels + if self.max_2Dsize is None: + self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) + newradii = self.radii.detach()[visible_mask] + self.max_2Dsize[visible_mask] = torch.maximum( + self.max_2Dsize[visible_mask], + newradii / float(max(self.last_size[0], self.last_size[1])), + ) + + def set_crop(self, crop_box: Optional[OrientedBox]): + self.crop_box = crop_box + + def set_background(self, background_color: torch.Tensor): + assert background_color.shape == (3,) + self.background_color = background_color + + def refinement_after(self, optimizers: Optimizers, step): + assert step == self.step + if self.step <= self.config.warmup_length: + return + with torch.no_grad(): + # Offset all the opacity reset logic by refine_every so that we don't + # save checkpoints right when the opacity is reset (saves every 2k) + # then cull + # only split/cull if we've seen every image since opacity reset + reset_interval = self.config.reset_alpha_every * self.config.refine_every + do_densification = ( + self.step < self.config.stop_split_at + and self.step % reset_interval + > self.num_train_data + self.config.refine_every + ) + if do_densification: + # then we densify + assert ( + self.xys_grad_norm is not None + and self.vis_counts is not None + and self.max_2Dsize is not None + ) + avg_grad_norm = ( + (self.xys_grad_norm / self.vis_counts) + * 0.5 + * max(self.last_size[0], self.last_size[1]) + ) + high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze() + splits = ( + self.scales.exp().max(dim=-1).values + > self.config.densify_size_thresh + ).squeeze() + if self.step < self.config.stop_screen_size_at: + splits |= ( + self.max_2Dsize > self.config.split_screen_size + ).squeeze() + splits &= high_grads + nsamps = self.config.n_split_samples + split_params = self.split_gaussians(splits, nsamps) + + dups = ( + self.scales.exp().max(dim=-1).values + <= self.config.densify_size_thresh + ).squeeze() + dups &= high_grads + dup_params = self.dup_gaussians(dups) + for name, param in self.gauss_params.items(): + self.gauss_params[name] = torch.nn.Parameter( + torch.cat( + [param.detach(), split_params[name], dup_params[name]], + dim=0, + ) + ) + # append zeros to the max_2Dsize tensor + self.max_2Dsize = torch.cat( + [ + self.max_2Dsize, + torch.zeros_like(split_params["scales"][:, 0]), + torch.zeros_like(dup_params["scales"][:, 0]), + ], + dim=0, + ) + + split_idcs = torch.where(splits)[0] + self.dup_in_all_optim(optimizers, split_idcs, nsamps) + + dup_idcs = torch.where(dups)[0] + self.dup_in_all_optim(optimizers, dup_idcs, 1) + + # After a guassian is split into two new gaussians, the original one should also be pruned. + splits_mask = torch.cat( + ( + splits, + torch.zeros( + nsamps * splits.sum() + dups.sum(), + device=self.device, + dtype=torch.bool, + ), + ) + ) + + deleted_mask = self.cull_gaussians(splits_mask) + elif ( + self.step >= self.config.stop_split_at + and self.config.continue_cull_post_densification + ): + deleted_mask = self.cull_gaussians() + else: + # if we donot allow culling post refinement, no more gaussians will be pruned. + deleted_mask = None + + if deleted_mask is not None: + self.remove_from_all_optim(optimizers, deleted_mask) + + if ( + self.step < self.config.stop_split_at + and self.step % reset_interval == self.config.refine_every + ): + # Reset value is set to be twice of the cull_alpha_thresh + reset_value = self.config.cull_alpha_thresh * 2.0 + self.opacities.data = torch.clamp( + self.opacities.data, + max=torch.logit( + torch.tensor(reset_value, device=self.device) + ).item(), + ) + # reset the exp of optimizer + optim = optimizers.optimizers["opacities"] + param = optim.param_groups[0]["params"][0] + param_state = optim.state[param] + param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"]) + param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"]) + + self.xys_grad_norm = None + self.vis_counts = None + self.max_2Dsize = None + + def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None): + """ + This function deletes gaussians with under a certain opacity threshold + extra_cull_mask: a mask indicates extra gaussians to cull besides existing culling criterion + """ + n_bef = self.num_points + # cull transparent ones + culls = ( + torch.sigmoid(self.opacities) < self.config.cull_alpha_thresh + ).squeeze() + below_alpha_count = torch.sum(culls).item() + toobigs_count = 0 + if extra_cull_mask is not None: + culls = culls | extra_cull_mask + if self.step > self.config.refine_every * self.config.reset_alpha_every: + # cull huge ones + toobigs = ( + torch.exp(self.scales).max(dim=-1).values + > self.config.cull_scale_thresh + ).squeeze() + if self.step < self.config.stop_screen_size_at: + # cull big screen space + if self.max_2Dsize is not None: + toobigs = ( + toobigs + | (self.max_2Dsize > self.config.cull_screen_size).squeeze() + ) + culls = culls | toobigs + toobigs_count = torch.sum(toobigs).item() + for name, param in self.gauss_params.items(): + self.gauss_params[name] = torch.nn.Parameter(param[~culls]) + + CONSOLE.log( + f"Culled {n_bef - self.num_points} gaussians " + f"({below_alpha_count} below alpha thresh, {toobigs_count} too bigs, {self.num_points} remaining)" + ) + + return culls + + def split_gaussians(self, split_mask, samps): + """ + This function splits gaussians that are too large + """ + n_splits = split_mask.sum().item() + CONSOLE.log( + f"Splitting {split_mask.sum().item()/self.num_points} gaussians: {n_splits}/{self.num_points}" + ) + centered_samples = torch.randn( + (samps * n_splits, 3), device=self.device + ) # Nx3 of axis-aligned scales + scaled_samples = ( + torch.exp(self.scales[split_mask].repeat(samps, 1)) * centered_samples + ) # how these scales are rotated + quats = self.quats[split_mask] / self.quats[split_mask].norm( + dim=-1, keepdim=True + ) # normalize them first + rots = quat_to_rotmat(quats.repeat(samps, 1)) # how these scales are rotated + rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze() + new_means = rotated_samples + self.means[split_mask].repeat(samps, 1) + # step 2, sample new colors + new_features_dc = self.features_dc[split_mask].repeat(samps, 1) + new_features_rest = self.features_rest[split_mask].repeat(samps, 1, 1) + # step 3, sample new opacities + new_opacities = self.opacities[split_mask].repeat(samps, 1) + # step 4, sample new scales + size_fac = 1.6 + new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat( + samps, 1 + ) + self.scales[split_mask] = torch.log( + torch.exp(self.scales[split_mask]) / size_fac + ) + # step 5, sample new quats + new_quats = self.quats[split_mask].repeat(samps, 1) + out = { + "means": new_means, + "features_dc": new_features_dc, + "features_rest": new_features_rest, + "opacities": new_opacities, + "scales": new_scales, + "quats": new_quats, + } + for name, param in self.gauss_params.items(): + if name not in out: + out[name] = param[split_mask].repeat(samps, 1) + return out + + def dup_gaussians(self, dup_mask): + """ + This function duplicates gaussians that are too small + """ + n_dups = dup_mask.sum().item() + CONSOLE.log( + f"Duplicating {dup_mask.sum().item()/self.num_points} gaussians: {n_dups}/{self.num_points}" + ) + new_dups = {} + for name, param in self.gauss_params.items(): + new_dups[name] = param[dup_mask] + return new_dups + + def get_training_callbacks( + self, training_callback_attributes: TrainingCallbackAttributes + ) -> List[TrainingCallback]: + cbs = [] + cbs.append( + TrainingCallback( + [TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], self.step_cb + ) + ) + # The order of these matters + cbs.append( + TrainingCallback( + [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], + self.after_train, + ) + ) + cbs.append( + TrainingCallback( + [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], + self.refinement_after, + update_every_num_iters=self.config.refine_every, + args=[training_callback_attributes.optimizers], + ) + ) + return cbs + + def step_cb(self, step): + self.step = step + + def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]: + # Here we explicitly use the means, scales as parameters so that the user can override this function and + # specify more if they want to add more optimizable params to gaussians. + return { + name: [self.gauss_params[name]] + for name in [ + "means", + "scales", + "quats", + "features_dc", + "features_rest", + "opacities", + ] + } + + def get_param_groups(self) -> Dict[str, List[Parameter]]: + """Obtain the parameter groups for the optimizers + + Returns: + Mapping of different parameter groups + """ + gps = self.get_gaussian_param_groups() + gps["bil_grids"] = list(self.bil_grids.parameters()) + self.camera_optimizer.get_param_groups(param_groups=gps) + return gps + + def _get_downscale_factor(self): + if self.training: + return 2 ** max( + ( + self.config.num_downscales + - self.step // self.config.resolution_schedule + ), + 0, + ) + else: + return 1 + + def _downscale_if_required(self, image): + d = self._get_downscale_factor() + if d > 1: + return resize_image(image, d) + return image + + @staticmethod + def get_empty_outputs( + width: int, height: int, background: torch.Tensor + ) -> Dict[str, Union[torch.Tensor, List]]: + rgb = background.repeat(height, width, 1) + depth = background.new_ones(*rgb.shape[:2], 1) * 10 + accumulation = background.new_zeros(*rgb.shape[:2], 1) + return { + "rgb": rgb, + "depth": depth, + "accumulation": accumulation, + "background": background, + } + + def _get_background_color(self): + if self.config.background_color == "random": + if self.training: + background = torch.rand(3, device=self.device) + else: + background = self.background_color.to(self.device) + elif self.config.background_color == "white": + background = torch.ones(3, device=self.device) + elif self.config.background_color == "black": + background = torch.zeros(3, device=self.device) + else: + raise ValueError(f"Unknown background color {self.config.background_color}") + return background + + def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: + """Takes in a Ray Bundle and returns a dictionary of outputs. + + Args: + ray_bundle: Input bundle of rays. This raybundle should have all the + needed information to compute the outputs. + + Returns: + Outputs of model. (ie. rendered colors) + """ + if not isinstance(camera, Cameras): + print("Called get_outputs with not a camera") + return {} + + if self.training: + assert camera.shape[0] == 1, "Only one camera at a time" + optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera) + else: + optimized_camera_to_world = camera.camera_to_worlds + + # cropping + if self.crop_box is not None and not self.training: + crop_ids = self.crop_box.within(self.means).squeeze() + if crop_ids.sum() == 0: + return self.get_empty_outputs( + int(camera.width.item()), + int(camera.height.item()), + self.background_color, + ) + else: + crop_ids = None + + if crop_ids is not None: + opacities_crop = self.opacities[crop_ids] + means_crop = self.means[crop_ids] + features_dc_crop = self.features_dc[crop_ids] + features_rest_crop = self.features_rest[crop_ids] + scales_crop = self.scales[crop_ids] + quats_crop = self.quats[crop_ids] + else: + opacities_crop = self.opacities + means_crop = self.means + features_dc_crop = self.features_dc + features_rest_crop = self.features_rest + scales_crop = self.scales + quats_crop = self.quats + + colors_crop = torch.cat( + (features_dc_crop[:, None, :], features_rest_crop), dim=1 + ) + + BLOCK_WIDTH = ( + 16 # this controls the tile size of rasterization, 16 is a good default + ) + camera_scale_fac = self._get_downscale_factor() + camera.rescale_output_resolution(1 / camera_scale_fac) + viewmat = get_viewmat(optimized_camera_to_world) + K = camera.get_intrinsics_matrices().cuda() + W, H = int(camera.width.item()), int(camera.height.item()) + self.last_size = (H, W) + camera.rescale_output_resolution(camera_scale_fac) # type: ignore + + # apply the compensation of screen space blurring to gaussians + if self.config.rasterize_mode not in ["antialiased", "classic"]: + raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode) + + if self.config.output_depth_during_training or not self.training: + render_mode = "RGB+ED" + else: + render_mode = "RGB" + + if self.config.sh_degree > 0: + sh_degree_to_use = min( + self.step // self.config.sh_degree_interval, self.config.sh_degree + ) + else: + colors_crop = torch.sigmoid(colors_crop) + sh_degree_to_use = None + + render, alpha, info = rasterization( + means=means_crop, + quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True), + scales=torch.exp(scales_crop), + opacities=torch.sigmoid(opacities_crop).squeeze(-1), + colors=colors_crop, + viewmats=viewmat, # [1, 4, 4] + Ks=K, # [1, 3, 3] + width=W, + height=H, + tile_size=BLOCK_WIDTH, + packed=False, + near_plane=0.01, + far_plane=1e10, + render_mode=render_mode, + sh_degree=sh_degree_to_use, + sparse_grad=False, + absgrad=True, + rasterize_mode=self.config.rasterize_mode, + # set some threshold to disregrad small gaussians for faster rendering. + # radius_clip=3.0, + ) + if self.training and info["means2d"].requires_grad: + info["means2d"].retain_grad() + self.xys = info["means2d"] # [1, N, 2] + self.radii = info["radii"][0] # [N] + alpha = alpha[:, ...] + + background = self._get_background_color() + rgb = render[:, ..., :3] + (1 - alpha) * background + rgb = torch.clamp(rgb, 0.0, 1.0) + + # apply bilateral grid + if self.config.use_bilateral_grid and self.training: + if camera.metadata is not None and "cam_idx" in camera.metadata: + cam_idx = camera.metadata["cam_idx"] + if cam_idx != 0: + # make xy grid + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, 1.0, H), + torch.linspace(0, 1.0, W), + indexing="ij", + ) + grid_xy = ( + torch.stack([grid_x, grid_y], dim=-1) + .unsqueeze(0) + .to(self.device) + ) + + # prepare grid idx + # grid_idx = ( + # torch.ones((H, W), dtype=torch.long, device=self.device) * cam_idx + # ) + # grid_idx = grid_idx.unsqueeze(-1) # [H, W, 1] + # process rgb + out = slice( + bil_grids=self.bil_grids, + rgb=rgb, + xy=grid_xy, + grid_idx=torch.tensor( + cam_idx, device=self.device, dtype=torch.long + ), + ) + rgb = out["rgb"] + + if render_mode == "RGB+ED": + depth_im = render[:, ..., 3:4] + depth_im = torch.where( + alpha > 0, depth_im, depth_im.detach().max() + ).squeeze(0) + else: + depth_im = None + + if background.shape[0] == 3 and not self.training: + background = background.expand(H, W, 3) + + return { + "rgb": rgb.squeeze(0), # type: ignore + "depth": depth_im, # type: ignore + "accumulation": alpha.squeeze(0), # type: ignore + "background": background, # type: ignore + } # type: ignore + + def get_gt_img(self, image: torch.Tensor): + """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose + + Args: + image: tensor.Tensor in type uint8 or float32 + """ + if image.dtype == torch.uint8: + image = image.float() / 255.0 + gt_img = self._downscale_if_required(image) + return gt_img.to(self.device) + + def composite_with_background(self, image, background) -> torch.Tensor: + """Composite the ground truth image with a background color when it has an alpha channel. + + Args: + image: the image to composite + background: the background color + """ + if image.shape[2] == 4: + alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3)) + return alpha * image[..., :3] + (1 - alpha) * background + else: + return image + + def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: + """Compute and returns metrics. + + Args: + outputs: the output to compute loss dict to + batch: ground truth batch corresponding to outputs + """ + gt_rgb = self.composite_with_background( + self.get_gt_img(batch["image"]), outputs["background"] + ) + metrics_dict = {} + predicted_rgb = outputs["rgb"] + metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) + + metrics_dict["gaussian_count"] = self.num_points + + self.camera_optimizer.get_metrics_dict(metrics_dict) + return metrics_dict + + def get_loss_dict( + self, outputs, batch, metrics_dict=None + ) -> Dict[str, torch.Tensor]: + """Computes and returns the losses dict. + + Args: + outputs: the output to compute loss dict to + batch: ground truth batch corresponding to outputs + metrics_dict: dictionary of metrics, some of which we can use for loss + """ + gt_img = self.composite_with_background( + self.get_gt_img(batch["image"]), outputs["background"] + ) + pred_img = outputs["rgb"] + + # Set masked part of both ground-truth and rendered image to black. + # This is a little bit sketchy for the SSIM loss. + if "mask" in batch: + # batch["mask"] : [H, W, 1] + mask = self._downscale_if_required(batch["mask"]) + mask = mask.to(self.device) + assert mask.shape[:2] == gt_img.shape[:2] == pred_img.shape[:2] + gt_img = gt_img * mask + pred_img = pred_img * mask + + Ll1 = torch.abs(gt_img - pred_img).mean() + simloss = 1 - self.ssim( + gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...] + ) + if self.config.use_scale_regularization and self.step % 10 == 0: + scale_exp = torch.exp(self.scales) + scale_reg = ( + torch.maximum( + scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), + torch.tensor(self.config.max_gauss_ratio), + ) + - self.config.max_gauss_ratio + ) + scale_reg = 0.1 * scale_reg.mean() + else: + scale_reg = torch.tensor(0.0).to(self.device) + + loss_dict = { + "main_loss": (1 - self.config.ssim_lambda) * Ll1 + + self.config.ssim_lambda * simloss, + "scale_reg": scale_reg, + } + + if self.training: + # Add loss from camera optimizer + self.camera_optimizer.get_loss_dict(loss_dict) + loss_dict["tv_loss"] = 10 * total_variation_loss(self.bil_grids.grids) + + return loss_dict + + @torch.no_grad() + def get_outputs_for_camera( + self, camera: Cameras, obb_box: Optional[OrientedBox] = None + ) -> Dict[str, torch.Tensor]: + """Takes in a camera, generates the raybundle, and computes the output of the model. + Overridden for a camera-based gaussian model. + + Args: + camera: generates raybundle + """ + assert camera is not None, "must provide camera to gaussian model" + self.set_crop(obb_box) + outs = self.get_outputs(camera.to(self.device)) + return outs # type: ignore + + def get_image_metrics_and_images( + self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] + ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: + """Writes the test image outputs. + + Args: + image_idx: Index of the image. + step: Current step. + batch: Batch of data. + outputs: Outputs of the model. + + Returns: + A dictionary of metrics. + """ + gt_rgb = self.composite_with_background( + self.get_gt_img(batch["image"]), outputs["background"] + ) + predicted_rgb = outputs["rgb"] + + combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) + + # Switch images from [H, W, C] to [1, C, H, W] for metrics computations + gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] + predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] + + psnr = self.psnr(gt_rgb, predicted_rgb) + ssim = self.ssim(gt_rgb, predicted_rgb) + lpips = self.lpips(gt_rgb, predicted_rgb) + + # all of these metrics will be logged as scalars + metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore + metrics_dict["lpips"] = float(lpips) + + images_dict = {"img": combined_rgb} + + return metrics_dict, images_dict From b6979a4f3bf2534340cb45b1b7ccbe6ed54c4241 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 20 Jul 2024 11:29:47 -0700 Subject: [PATCH 19/78] adding caching of compressed images to RAM to reduce disk bottleneck --- .../data/datamanagers/base_datamanager.py | 45 ++++++++++++------- nerfstudio/data/datasets/base_dataset.py | 16 ++++++- nerfstudio/data/utils/dataloaders.py | 32 ++++++++++++- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 2265f86d6a..432c880038 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -39,7 +39,7 @@ get_args, get_origin, ) - +import time import torch import tyro from torch import nn @@ -317,9 +317,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" - train_num_images_to_sample_from: int = -1 + train_num_images_to_sample_from: int = -1 # usually -1 """Number of images to sample during training iteration.""" - train_num_times_to_repeat_images: int = -1 + train_num_times_to_repeat_images: int = -1 # usually -1 """When not training on all images, number of iterations before picking new images. If -1, never pick new images.""" eval_num_rays_per_batch: int = 1024 @@ -339,13 +339,13 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = 1 + prefetch_factor: int = 2 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - dataloader_num_workers: int = 2 + dataloader_num_workers: int = 1 """The number of workers performing the dataloading from either disk/RAM, which - includes undistortion, pixel sampling, ray generation, collating, etc.""" - use_ray_train_dataloader: bool = True + includes collating, pixel sampling, unprojecting, ray generation etc.""" + use_ray_train_dataloader: bool = False """Allows parallelization of the dataloading process with multiple workers.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. @@ -397,6 +397,7 @@ def __init__( assert isinstance(self.input_dataset, Sized) # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) + """If True, cache all images to RAM as a collated""" # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from self.num_images_to_sample_from = num_images_to_sample_from self.device = device @@ -408,9 +409,9 @@ def __init__( self.pixel_sampler: PixelSampler = None self.ray_generator: RayGenerator = None self._cached_collated_batch = None - """_cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" + """self._cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker - """If True, _cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" + """If True, self._cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" def _get_pixel_sampler(self, dataset: "TDataset", num_rays_per_batch: int) -> PixelSampler: """copy-pasta from VanillaDataManager.""" @@ -457,12 +458,11 @@ def _get_batch_list(self, indices=None): ) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) num_threads = max(num_threads, 1) - # print('num_threads', num_threads) + # print('num_threads', num_threads) # prints 16 # NB: this is I/O heavy because we are going to disk and reading an image filename # hence multi-threaded inside the worker from tqdm.auto import tqdm - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) @@ -484,10 +484,21 @@ def _get_collated_batch(self, indices=None): collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ + start_time = time.time() batch_list = self._get_batch_list(indices=indices) + end_time = time.time() + batch_time = (end_time - start_time) * 1000 # Convert to milliseconds + with open('image_read_time.txt', 'w') as f: + f.write(f"Time to read images and create `batch_list`:: {batch_time:.2f} ms") # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate + # start_time = time.time() collated_batch = self.collate_fn(batch_list) + # end_time = time.time() + # collate_time = (end_time - start_time) * 1000 # Convert to milliseconds + # with open('collate_time.txt', 'w') as f: + # f.write(f"Time to collate images: {collate_time:.2f} ms") + # print(type(collated_batch)) # prints collated_batch = get_dict_to_torch( collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device ) @@ -675,7 +686,7 @@ def setup_train(self): self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, datamanager_config=self.config, - num_images_to_sample_from=100, # self.config.train_num_images_to_sample_from, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, device=self.device, collate_fn=self.config.collate_fn, ) @@ -695,16 +706,16 @@ def setup_train(self): else: self.iter_train_raybundles = None self.train_image_dataloader = CacheDataloader( - self.train_dataset, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + self.train_dataset, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, # batch_size + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # device=self.device, num_workers=self.world_size * 4 if self.config.dataloader_num_workers == -1 else self.config.dataloader_num_workers, prefetch_factor=2 - if self.config.dataloader_prefetch_size == -1 - else self.config.dataloader_prefetch_size, + if self.config.prefetch_factor == -1 + else self.config.prefetch_factor, pin_memory=True, collate_fn=self.config.collate_fn, exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index e16ea33482..272ddff952 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -18,6 +18,7 @@ from __future__ import annotations from copy import deepcopy +import io from pathlib import Path from typing import Dict, List, Literal @@ -45,7 +46,7 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): + def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = True): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -54,6 +55,13 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = self.cameras = deepcopy(dataparser_outputs.cameras) self.cameras.rescale_output_resolution(scaling_factor=scale_factor) self.mask_color = dataparser_outputs.metadata.get("mask_color", None) + self.cache_images = cache_images + """If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)""" + if cache_images: + self.binary_images = [] + for image_filename in self._dataparser_outputs.image_filenames: + with open(image_filename, 'rb') as f: + self.binary_images.append(io.BytesIO(f.read())) def __len__(self): return len(self._dataparser_outputs.image_filenames) @@ -65,7 +73,10 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: image_idx: The image index in the dataset. """ image_filename = self._dataparser_outputs.image_filenames[image_idx] - pil_image = Image.open(image_filename) + if self.cache_images: + pil_image = Image.open(self.binary_images[image_idx]) + else: + pil_image = Image.open(image_filename) if self.scale_factor != 1.0: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) @@ -159,3 +170,4 @@ def image_filenames(self) -> List[Path]: """ return self._dataparser_outputs.image_filenames + \ No newline at end of file diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 3a1d744f1d..8c75566f48 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -44,7 +44,7 @@ class CacheDataloader(DataLoader): Args: dataset: Dataset to sample from. num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. - num_times_to_repeat_images: How often to collate new images. -1 to never pick new images. + num_times_to_repeat_images: How many ray bundles to . -1 to never pick new images. device: Device to perform computation. collate_fn: The function we will use to collate our training data """ @@ -134,6 +134,7 @@ def __iter__(self): collated_batch = self.cached_collated_batch elif self.first_time or ( self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images + # if it's the first time, we need to ): # trigger a reset self.num_repeated = 0 @@ -147,6 +148,35 @@ def __iter__(self): yield collated_batch +import torch +class ParallelCacheDataloader(torch.utils.data.IterableDataset): + """Creates batches of the InputDataset return type with multiple workers, can be toggled to return image batches or RayBundles + When return image batches + """ + def __init__( + self, + input_dataset: Dataset, + num_images_to_sample_from: int = -1, + device: Union[torch.device, str] = "cpu", + collate_fn: Callable[[Any], Any] = nerfstudio_collate, + exclude_batch_keys_from_device: Optional[List[str]] = None, + num_image_load_threads : int = 2, + cache_all_n_shard_per_worker : bool = True, + ): + if exclude_batch_keys_from_device is None: + exclude_batch_keys_from_device = ["image"] + self.input_dataset = input_dataset + assert isinstance(self.input_dataset, Sized) + + self.num_images_to_sample_from = num_images_to_sample_from + """The size of a collated_batch of images""" + + def __iter__(self): + pass + + + + # class RayBatchStream(torch.utils.data.IterableDataset): # def __init__( # self, From 81dbf7cb682a7f83a76fca8748f3b4d15b5d9d36 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 21 Jul 2024 17:16:48 -0700 Subject: [PATCH 20/78] added caching to RAM for masks --- .../data/datamanagers/base_datamanager.py | 22 ++++++++++++------- nerfstudio/data/datasets/base_dataset.py | 9 +++++++- nerfstudio/data/utils/data_utils.py | 4 ++-- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 432c880038..50e0cd46e8 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -317,9 +317,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" - train_num_images_to_sample_from: int = -1 # usually -1 + train_num_images_to_sample_from: int = 100 # usually -1 """Number of images to sample during training iteration.""" - train_num_times_to_repeat_images: int = -1 # usually -1 + train_num_times_to_repeat_images: int = 5 # usually -1 """When not training on all images, number of iterations before picking new images. If -1, never pick new images.""" eval_num_rays_per_batch: int = 1024 @@ -342,10 +342,10 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = 2 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - dataloader_num_workers: int = 1 + dataloader_num_workers: int = 4 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - use_ray_train_dataloader: bool = False + use_ray_train_dataloader: bool = True """Allows parallelization of the dataloading process with multiple workers.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. @@ -384,6 +384,7 @@ def __init__( input_dataset: Dataset, datamanager_config: DataManagerConfig, num_images_to_sample_from: int = -1, # passed in from VanillaDataManager + num_times_to_repeat_images: int = -1, # passed in from VanillaDataManager device: Union[torch.device, str] = "cpu", collate_fn: Callable[[Any], Any] = nerfstudio_collate, exclude_batch_keys_from_device: Optional[List[str]] = None, @@ -400,6 +401,7 @@ def __init__( """If True, cache all images to RAM as a collated""" # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from self.num_images_to_sample_from = num_images_to_sample_from + self.num_times_to_repeat_images = num_times_to_repeat_images self.device = device self.collate_fn = collate_fn self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults @@ -526,10 +528,13 @@ def __iter__(self): self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) i = 0 while True: - if i % 5 == 0: + if i % self.num_times_to_repeat_images == 0: r.shuffle(worker_indices) - image_indices = worker_indices[:self.num_images_to_sample_from] # get a total of 'num_images_to_sample_from' image indices - + # get a total of 'num_images_to_sample_from' image indices, if self.num_images_to_sample_from + if self.num_images_to_sample_from == -1: + image_indices = worker_indices + else: + image_indices = worker_indices[:self.num_images_to_sample_from] # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. collated_batch = self._get_collated_batch(image_indices) i += 1 @@ -687,6 +692,7 @@ def setup_train(self): input_dataset=self.train_dataset, datamanager_config=self.config, num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, collate_fn=self.config.collate_fn, ) @@ -699,7 +705,7 @@ def setup_train(self): pin_memory=False, # Our dataset does batching / collation collate_fn=identity, - # pin_memory_device=self.device, # did not actually speed up my implementation + pin_memory_device=self.device, # did not actually speed up my implementation ) self.iter_train_image_dataloader = None self.iter_train_raybundles = iter(self.ray_dataloader) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 272ddff952..e0c89e70ea 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -59,9 +59,13 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = """If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)""" if cache_images: self.binary_images = [] + self.binary_masks = [] for image_filename in self._dataparser_outputs.image_filenames: with open(image_filename, 'rb') as f: self.binary_images.append(io.BytesIO(f.read())) + for mask_filename in self._dataparser_outputs.mask_filenames: + with open(mask_filename, 'rb') as f: + self.binary_masks.append(io.BytesIO(f.read())) def __len__(self): return len(self._dataparser_outputs.image_filenames) @@ -136,7 +140,10 @@ def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "fl data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: - mask_filepath = self._dataparser_outputs.mask_filenames[image_idx] + if self.cache_images: + mask_filepath = self.binary_masks[image_idx] + else: + mask_filepath = self._dataparser_outputs.mask_filenames[image_idx] data["mask"] = get_image_mask_tensor_from_path(filepath=mask_filepath, scale_factor=self.scale_factor) assert ( data["mask"].shape[:2] == data["image"].shape[:2] diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index c81101c4f3..e114ccd767 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -14,7 +14,7 @@ """Utility functions to allow easy re-use of common operations across dataloaders""" from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Tuple, Union, IO import cv2 import numpy as np @@ -22,7 +22,7 @@ from PIL import Image -def get_image_mask_tensor_from_path(filepath: Path, scale_factor: float = 1.0) -> torch.Tensor: +def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_factor: float = 1.0) -> torch.Tensor: """ Utility function to read a mask image from the given path and return a boolean tensor """ From 55ca71d559eed5745763d4d6755a9529e498f29d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 02:20:31 -0700 Subject: [PATCH 21/78] found fast way to collate - many tricks applied --- .../data/datamanagers/base_datamanager.py | 68 +++++++++++-------- nerfstudio/data/datasets/base_dataset.py | 18 ++--- nerfstudio/data/utils/data_utils.py | 33 ++++++++- 3 files changed, 79 insertions(+), 40 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 50e0cd46e8..21b52b25d1 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -56,7 +56,7 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import ( # , RayBatchStream +from nerfstudio.data.utils.dataloaders import ( CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader, @@ -339,7 +339,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = 2 + prefetch_factor: int = 4 # increasing prefetch_factor was not beneficial """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" dataloader_num_workers: int = 4 @@ -347,6 +347,8 @@ class VanillaDataManagerConfig(DataManagerConfig): includes collating, pixel sampling, unprojecting, ray generation etc.""" use_ray_train_dataloader: bool = True """Allows parallelization of the dataloading process with multiple workers.""" + cache_binaries: bool = True + """If True, caches the images as binary strings to RAM""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -376,7 +378,9 @@ def __post_init__(self): from torch.utils.data import Dataset from nerfstudio.utils.misc import get_dict_to_torch +from tqdm.auto import tqdm +from torch.profiler import profile, record_function, ProfilerActivity class RayBatchStream(torch.utils.data.IterableDataset): def __init__( @@ -389,8 +393,7 @@ def __init__( collate_fn: Callable[[Any], Any] = nerfstudio_collate, exclude_batch_keys_from_device: Optional[List[str]] = None, num_image_load_threads: int = 4, - cache_all_n_shard_per_worker: bool = True, # When False, always getting Killed/bugs for some reason... why? - # when cache_all_n_shard_per_worker True, getting killed because caching everything is not good + cache_all_n_shard_per_worker: bool = True, ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -462,20 +465,22 @@ def _get_batch_list(self, indices=None): num_threads = max(num_threads, 1) # print('num_threads', num_threads) # prints 16 + + # NB: this is I/O heavy because we are going to disk and reading an image filename # hence multi-threaded inside the worker - from tqdm.auto import tqdm with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) results.append(res) - # for res in track(results, description="Loading data batch", transient=True): # for res in tqdm(results, desc='_get_batch_list'): - if self.cache_all_n_shard_per_worker: - results = tqdm(results) + results = tqdm(results) # does not effect times, tested many times for res in results: batch_list.append(res.result()) + + # for idx in tqdm(indices): # this is slower compared to using threads + # batch_list.append(self.input_dataset.__getitem__(idx)) return batch_list def _get_collated_batch(self, indices=None): @@ -486,24 +491,16 @@ def _get_collated_batch(self, indices=None): collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ - start_time = time.time() - batch_list = self._get_batch_list(indices=indices) - end_time = time.time() - batch_time = (end_time - start_time) * 1000 # Convert to milliseconds - with open('image_read_time.txt', 'w') as f: - f.write(f"Time to read images and create `batch_list`:: {batch_time:.2f} ms") - # print(type(batch_list[0])) # prints + with record_function("_get_batch_list"): + batch_list = self._get_batch_list(indices=indices) + # # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate - # start_time = time.time() - collated_batch = self.collate_fn(batch_list) - # end_time = time.time() - # collate_time = (end_time - start_time) * 1000 # Convert to milliseconds - # with open('collate_time.txt', 'w') as f: - # f.write(f"Time to collate images: {collate_time:.2f} ms") - # print(type(collated_batch)) # prints - collated_batch = get_dict_to_torch( - collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device - ) + with record_function("nerfstudio_collate"): + collated_batch = self.collate_fn(batch_list) + with record_function("sending to GPU"): + collated_batch = get_dict_to_torch( + collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + ) return collated_batch def __iter__(self): @@ -521,6 +518,8 @@ def __iter__(self): worker_indices = dataset_indices[ slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load + if self.cache_all_n_shard_per_worker: + self._cached_collated_batch = self._get_collated_batch(worker_indices) r = random.Random(3301) num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch # default train_num_rays_per_batch is 4096 worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) @@ -528,15 +527,23 @@ def __iter__(self): self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) i = 0 while True: - if i % self.num_times_to_repeat_images == 0: + if self.cache_all_n_shard_per_worker: + collated_batch = self._cached_collated_batch + elif i % self.num_times_to_repeat_images == 0: r.shuffle(worker_indices) - # get a total of 'num_images_to_sample_from' image indices, if self.num_images_to_sample_from - if self.num_images_to_sample_from == -1: + + if self.num_images_to_sample_from == -1: # if -1, the worker gets all available indices in its partition image_indices = worker_indices - else: + else: # get a total of 'num_images_to_sample_from' image indices image_indices = worker_indices[:self.num_images_to_sample_from] # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. + # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: + # with record_function("process_images"): collated_batch = self._get_collated_batch(image_indices) + # with open('_get_batch_list_profile.txt', 'w') as f: + # f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + # f.write("\n\nMemory Usage:\n") + # f.write(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) i += 1 """ Here, the variable 'batch' refers to the output of our pixel sampler. @@ -695,6 +702,7 @@ def setup_train(self): num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, collate_fn=self.config.collate_fn, + cache_all_n_shard_per_worker=False, ) self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, @@ -702,7 +710,7 @@ def setup_train(self): num_workers=self.config.dataloader_num_workers, prefetch_factor=self.config.prefetch_factor, shuffle=False, - pin_memory=False, + pin_memory=True, # Our dataset does batching / collation collate_fn=identity, pin_memory_device=self.device, # did not actually speed up my implementation diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index e0c89e70ea..e535bffb38 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -32,8 +32,8 @@ from nerfstudio.cameras.cameras import Cameras from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs -from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path - +from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path, pil_to_numpy +from torch.profiler import record_function class InputDataset(Dataset): """Dataset that returns images. @@ -63,9 +63,10 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = for image_filename in self._dataparser_outputs.image_filenames: with open(image_filename, 'rb') as f: self.binary_images.append(io.BytesIO(f.read())) - for mask_filename in self._dataparser_outputs.mask_filenames: - with open(mask_filename, 'rb') as f: - self.binary_masks.append(io.BytesIO(f.read())) + if self._dataparser_outputs.mask_filenames is not None: + for mask_filename in self._dataparser_outputs.mask_filenames: + with open(mask_filename, 'rb') as f: + self.binary_masks.append(io.BytesIO(f.read())) def __len__(self): return len(self._dataparser_outputs.image_filenames) @@ -85,11 +86,10 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - image = np.array(pil_image, dtype="uint8") # shape is (h, w) or (h, w, 3 or 4) + image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "float32" if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 - assert image.dtype == np.uint8 assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct." return image @@ -99,7 +99,7 @@ def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image Args: image_idx: The image index in the dataset. """ - image = torch.from_numpy(self.get_numpy_image(image_idx).astype("float32") / 255.0) + image = torch.from_numpy(self.get_numpy_image(image_idx)) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 @@ -113,7 +113,7 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w Args: image_idx: The image index in the dataset. """ - image = torch.from_numpy(self.get_numpy_image(image_idx)) + image = torch.from_numpy(self.get_numpy_image(image_idx).astype(np.uint8)) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index e114ccd767..bb62fdd968 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -22,6 +22,37 @@ from PIL import Image +def pil_to_numpy(im: Image) -> np.ndarray: + """Converts a PIL Image object to a NumPy array. + + Args: + im (PIL.Image.Image): The input PIL Image object. + + Returns: + numpy.ndarray: float 32 ndarray representing the image normalized between 0 and 1. + """ + im.load() + + # Unpack data + e = Image._getencoder(im.mode, "raw", im.mode) + e.setimage(im.im) + + # NumPy buffer for the result + shape, typestr = Image._conv_type_shape(im) + data = np.empty(shape, dtype=np.dtype(typestr)) + mem = data.data.cast("B", (data.data.nbytes,)) + + bufsize, s, offset = 65536, 0, 0 + while not s: + l, s, d = e.encode(bufsize) + mem[offset:offset + len(d)] = d + offset += len(d) + if s < 0: + raise RuntimeError("encoder error %d in tobytes" % s) + + return data / np.float32(255) + + def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_factor: float = 1.0) -> torch.Tensor: """ Utility function to read a mask image from the given path and return a boolean tensor @@ -31,7 +62,7 @@ def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_fact width, height = pil_mask.size newsize = (int(width * scale_factor), int(height * scale_factor)) pil_mask = pil_mask.resize(newsize, resample=Image.Resampling.NEAREST) - mask_tensor = torch.from_numpy(np.array(pil_mask)).unsqueeze(-1).bool() + mask_tensor = torch.from_numpy(pil_to_numpy(pil_mask)).unsqueeze(-1).bool() if len(mask_tensor.shape) != 3: raise ValueError("The mask image should have 1 channel") return mask_tensor From 3b4f091f6b90e99d05f42e8eebd90947c62ebde4 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 15:56:24 -0700 Subject: [PATCH 22/78] quick update to aria to test on different datasets --- .../scripts/datasets/process_project_aria.py | 73 ++++++++++++++----- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index ba5d8c9f6c..7947d3c20c 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -320,13 +320,33 @@ def main(self) -> None: "frames": [], } points = [] + names = ["camera-rgb", "camera-slam-left", "camera-slam-right"] + total_num_rgb_images_per_recording_list = [] + total_num_images_per_recording_list = [] + + # Count the total number of images per dataset + for rec_i, (vrs_file, mps_data_dir, points_file) in enumerate(vrs_mps_points_triplets): + provider = create_vrs_data_provider(str(vrs_file.absolute())) + assert provider is not None, "Cannot open file" + stream_ids = [provider.get_stream_id_from_label(name) for name in names] + total_num_rgb_images_per_recording_list.append(provider.get_num_data(stream_ids[0])) + total_num_images_per_recording_list.append( + sum([provider.get_num_data(stream_id) for stream_id in stream_ids]) + ) + if not self.include_side_cameras: + assert self.max_dataset_size <= sum( + total_num_rgb_images_per_recording_list + ), "Specify a dataset size at most the number of RGB images provided" + else: + assert self.max_dataset_size <= sum( + total_num_images_per_recording_list + ), "Specify a dataset size at most the number of images provided" # Process the aria data of each user one by one for rec_i, (vrs_file, mps_data_dir, points_file) in enumerate(vrs_mps_points_triplets): provider = create_vrs_data_provider(str(vrs_file.absolute())) assert provider is not None, "Cannot open file" - names = ["camera-rgb", "camera-slam-left", "camera-slam-right"] name_to_camera = { name: get_camera_calibs(provider, name) # type: ignore for name in names @@ -342,22 +362,20 @@ def main(self) -> None: print(f"Creating Aria frames for recording {rec_i + 1}...") CANONICAL_RGB_VALID_RADIUS = 707.5 # radius of a circular mask that represents the valid area on the camera's sensor plane. Pixels out of this circular region are considered invalid CANONICAL_RGB_WIDTH = 1408 - total_num_images_per_camera = provider.get_num_data(stream_ids[0]) - if self.max_dataset_size == -1: - num_images_to_sample_per_camera = total_num_images_per_camera - else: - num_images_to_sample_per_camera = ( - self.max_dataset_size // (len(vrs_mps_points_triplets) * 3) - if self.include_side_cameras - else self.max_dataset_size // len(vrs_mps_points_triplets) - ) - sampling_indicies = random.sample(range(total_num_images_per_camera), num_images_to_sample_per_camera) - if not self.include_side_cameras: + + if not self.include_side_cameras: # RGB images only + if self.max_dataset_size == -1: + sampling_indices = range(provider.get_num_data(stream_ids[0])) + else: + num_images_to_sample = ( + self.max_dataset_size * total_num_rgb_images_per_recording_list[rec_i] + ) // sum(total_num_rgb_images_per_recording_list) + sampling_indices = random.sample(range(provider.get_num_data(stream_ids[0])), num_images_to_sample) aria_rgb_frames = [ to_aria_image_frame( provider, index, name_to_camera, t_world_devices, self.output_dir, camera_name=names[0] ) - for index in sampling_indicies + for index in sampling_indices ] print(f"Creating NerfStudio frames for recording {rec_i + 1}...") nerfstudio_frames["frames"] += [to_nerfstudio_frame(frame) for frame in aria_rgb_frames] @@ -365,7 +383,26 @@ def main(self) -> None: aria_rgb_frames[0].camera.width / CANONICAL_RGB_WIDTH ) # to handle both high-res 2880 x 2880 aria captures nerfstudio_frames["fisheye_crop_radius"] = rgb_valid_radius - else: + else: # include the side grayscale cameras + total_num_images_per_camera_list = [provider.get_num_data(stream_id) for stream_id in stream_ids] + if self.max_dataset_size == -1: + sampling_indices_list = [range(num_images) for num_images in total_num_images_per_camera_list] + else: + total_num_images = sum( + total_num_images_per_camera_list + ) # total number of images for this recording + num_images_to_sample = ( + self.max_dataset_size // num_recordings + ) # total number of images to sample for this recording + num_images_to_sample_per_camera_list = [ + num_images_to_sample * num // total_num_images for num in total_num_images_per_camera_list + ] + sampling_indices_list = [ + random.sample( + range(total_num_images_per_camera_list[i]), num_images_to_sample_per_camera_list[i] + ) + for i in range(3) + ] aria_all3cameras_pinhole_frames = [ [ to_aria_image_frame( @@ -377,7 +414,7 @@ def main(self) -> None: camera_name=names[i], pinhole=True, ) - for index in range(provider.get_num_data(stream_id)) + for index in sampling_indices_list[i] ] for i, stream_id in enumerate(stream_ids) ] @@ -419,7 +456,7 @@ def main(self) -> None: points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore points_data = filter_points_from_confidence(points_data) points += [cast(Any, it).position_world for it in points_data] - print(len(nerfstudio_frames['frames'])) + if len(points) > 0: print("Saving found points to PLY...") print(f"Total number of points found: {len(points)} in {num_recordings} recording(s) provided") @@ -430,7 +467,7 @@ def main(self) -> None: nerfstudio_frames["ply_file_path"] = "global_points.ply" else: print("No global points found!") - + print(len(nerfstudio_frames["frames"])) # Write the json out to disk as transforms.json print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" @@ -440,4 +477,4 @@ def main(self) -> None: if __name__ == "__main__": tyro.extras.set_accent_color("bright_yellow") - tyro.cli(ProcessProjectAria).main() + tyro.cli(ProcessProjectAria).main() \ No newline at end of file From 7de192216b48eef77a250abdca5f55f812f40cc8 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 15:58:10 -0700 Subject: [PATCH 23/78] cleaned up the accelerated pil_to_numpy function --- nerfstudio/data/utils/data_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index bb62fdd968..0ee43378b8 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -29,8 +29,9 @@ def pil_to_numpy(im: Image) -> np.ndarray: im (PIL.Image.Image): The input PIL Image object. Returns: - numpy.ndarray: float 32 ndarray representing the image normalized between 0 and 1. + numpy.ndarray representing the image data. """ + # Load in image completely (PIL defaults to lazy loading) im.load() # Unpack data @@ -50,7 +51,7 @@ def pil_to_numpy(im: Image) -> np.ndarray: if s < 0: raise RuntimeError("encoder error %d in tobytes" % s) - return data / np.float32(255) + return data def get_image_mask_tensor_from_path(filepath: Union[Path, IO[bytes]], scale_factor: float = 1.0) -> torch.Tensor: From 9ceaad1458334901382ad9c9d89b7d9c08eb52c6 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 16:00:29 -0700 Subject: [PATCH 24/78] cleaning up PR --- nerfstudio/data/pixel_samplers.py | 4 +- nerfstudio/data/utils/nerfstudio_collate.py | 2 - nerfstudio/models/bilateral_splat.py | 1110 ------------------- 3 files changed, 1 insertion(+), 1115 deletions(-) delete mode 100644 nerfstudio/models/bilateral_splat.py diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index c355ae7830..ad11ee4094 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -304,9 +304,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2" num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images) - # print(batch.keys()) - # import time - # time.sleep(3) + if "mask" in batch: for i, num_rays in enumerate(num_rays_per_image): image_height, image_width, _ = batch["image"][i].shape diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index 9a859a7fec..8c8a633fb8 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -98,8 +98,6 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum(x.numel() for x in batch) - import warnings - warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') storage = elem.storage()._new_shared(numel, device=elem.device) out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out) diff --git a/nerfstudio/models/bilateral_splat.py b/nerfstudio/models/bilateral_splat.py deleted file mode 100644 index 6a90c337c5..0000000000 --- a/nerfstudio/models/bilateral_splat.py +++ /dev/null @@ -1,1110 +0,0 @@ -# ruff: noqa: E741 -# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Gaussian Splatting implementation that combines many recent advancements. -""" - -from __future__ import annotations - -import math -from dataclasses import dataclass, field -from typing import Dict, List, Literal, Optional, Tuple, Type, Union - -import numpy as np -import torch -from gsplat.cuda_legacy._torch_impl import quat_to_rotmat - -try: - from gsplat.rendering import rasterization -except ImportError: - print("Please install gsplat>=1.0.0") -from gsplat.cuda_legacy._wrapper import num_sh_bases -from pytorch_msssim import SSIM -from torch.nn import Parameter - -from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig -from nerfstudio.cameras.cameras import Cameras -from nerfstudio.data.scene_box import OrientedBox -from nerfstudio.engine.callbacks import ( - TrainingCallback, - TrainingCallbackAttributes, - TrainingCallbackLocation, -) -from nerfstudio.engine.optimizers import Optimizers -from nerfstudio.models.base_model import Model, ModelConfig -from nerfstudio.utils.colors import get_color -from nerfstudio.utils.misc import torch_compile -from nerfstudio.utils.rich_utils import CONSOLE -from bilags.lib_bilagrid import slice, BilateralGrid, total_variation_loss - - -def random_quat_tensor(N): - """ - Defines a random quaternion tensor of shape (N, 4) - """ - u = torch.rand(N) - v = torch.rand(N) - w = torch.rand(N) - return torch.stack( - [ - torch.sqrt(1 - u) * torch.sin(2 * math.pi * v), - torch.sqrt(1 - u) * torch.cos(2 * math.pi * v), - torch.sqrt(u) * torch.sin(2 * math.pi * w), - torch.sqrt(u) * torch.cos(2 * math.pi * w), - ], - dim=-1, - ) - - -def RGB2SH(rgb): - """ - Converts from RGB values [0,1] to the 0th spherical harmonic coefficient - """ - C0 = 0.28209479177387814 - return (rgb - 0.5) / C0 - - -def SH2RGB(sh): - """ - Converts from the 0th spherical harmonic coefficient to RGB values [0,1] - """ - C0 = 0.28209479177387814 - return sh * C0 + 0.5 - - -def resize_image(image: torch.Tensor, d: int): - """ - Downscale images using the same 'area' method in opencv - - :param image shape [H, W, C] - :param d downscale factor (must be 2, 4, 8, etc.) - - return downscaled image in shape [H//d, W//d, C] - """ - import torch.nn.functional as tf - - image = image.to(torch.float32) - weight = (1.0 / (d * d)) * torch.ones( - (1, 1, d, d), dtype=torch.float32, device=image.device - ) - return ( - tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d) - .squeeze(1) - .permute(1, 2, 0) - ) - - -@torch_compile() -def get_viewmat(optimized_camera_to_world): - """ - function that converts c2w to gsplat world2camera matrix, using compile for some speed - """ - R = optimized_camera_to_world[:, :3, :3] # 3 x 3 - T = optimized_camera_to_world[:, :3, 3:4] # 3 x 1 - # flip the z and y axes to align with gsplat conventions - R = R * torch.tensor([[[1, -1, -1]]], device=R.device, dtype=R.dtype) - # analytic matrix inverse to get world2camera matrix - R_inv = R.transpose(1, 2) - T_inv = -torch.bmm(R_inv, T) - viewmat = torch.zeros(R.shape[0], 4, 4, device=R.device, dtype=R.dtype) - viewmat[:, 3, 3] = 1.0 # homogenous - viewmat[:, :3, :3] = R_inv - viewmat[:, :3, 3:4] = T_inv - return viewmat - - -@dataclass -class BilagsModelConfig(ModelConfig): - """Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting""" - - _target: Type = field(default_factory=lambda: BilagsModel) - warmup_length: int = 500 - """period of steps where refinement is turned off""" - refine_every: int = 100 - """period of steps where gaussians are culled and densified""" - resolution_schedule: int = 3000 - """training starts at 1/d resolution, every n steps this is doubled""" - background_color: Literal["random", "black", "white"] = "random" - """Whether to randomize the background color.""" - num_downscales: int = 2 - """at the beginning, resolution is 1/2^d, where d is this number""" - cull_alpha_thresh: float = 0.1 - """threshold of opacity for culling gaussians. One can set it to a lower value (e.g. 0.005) for higher quality.""" - cull_scale_thresh: float = 0.5 - """threshold of scale for culling huge gaussians""" - continue_cull_post_densification: bool = True - """If True, continue to cull gaussians post refinement""" - reset_alpha_every: int = 30 - """Every this many refinement steps, reset the alpha""" - densify_grad_thresh: float = 0.0008 - """threshold of positional gradient norm for densifying gaussians""" - densify_size_thresh: float = 0.01 - """below this size, gaussians are *duplicated*, otherwise split""" - n_split_samples: int = 2 - """number of samples to split gaussians into""" - sh_degree_interval: int = 1000 - """every n intervals turn on another sh degree""" - cull_screen_size: float = 0.15 - """if a gaussian is more than this percent of screen space, cull it""" - split_screen_size: float = 0.05 - """if a gaussian is more than this percent of screen space, split it""" - stop_screen_size_at: int = 4000 - """stop culling/splitting at this step WRT screen size of gaussians""" - random_init: bool = False - """whether to initialize the positions uniformly randomly (not SFM points)""" - num_random: int = 50000 - """Number of gaussians to initialize if random init is used""" - random_scale: float = 10.0 - "Size of the cube to initialize random gaussians within" - ssim_lambda: float = 0.2 - """weight of ssim loss""" - stop_split_at: int = 15000 - """stop splitting at this step""" - sh_degree: int = 3 - """maximum degree of spherical harmonics to use""" - use_scale_regularization: bool = False - """If enabled, a scale regularization introduced in PhysGauss (https://xpandora.github.io/PhysGaussian/) is used for reducing huge spikey gaussians.""" - max_gauss_ratio: float = 10.0 - """threshold of ratio of gaussian max to min scale before applying regularization - loss from the PhysGaussian paper - """ - output_depth_during_training: bool = False - """If True, output depth during training. Otherwise, only output depth during evaluation.""" - rasterize_mode: Literal["classic", "antialiased"] = "classic" - """ - Classic mode of rendering will use the EWA volume splatting with a [0.3, 0.3] screen space blurring kernel. This - approach is however not suitable to render tiny gaussians at higher or lower resolution than the captured, which - results "aliasing-like" artifacts. The antialiased mode overcomes this limitation by calculating compensation factors - and apply them to the opacities of gaussians to preserve the total integrated density of splats. - - However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that - were implemented for classic mode can not render antialiased mode PLY properly without modifications. - """ - camera_optimizer: CameraOptimizerConfig = field( - default_factory=lambda: CameraOptimizerConfig(mode="off") - ) - """Config of the camera optimizer to use""" - use_bilateral_grid: bool = True - - -class BilagsModel(Model): - """Nerfstudio's implementation of Gaussian Splatting - - Args: - config: Splatfacto configuration to instantiate model - """ - - config: BilagsModelConfig - - def __init__( - self, - *args, - seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ): - self.seed_points = seed_points - super().__init__(*args, **kwargs) - - def populate_modules(self): - if self.seed_points is not None and not self.config.random_init: - means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) - else: - means = torch.nn.Parameter( - (torch.rand((self.config.num_random, 3)) - 0.5) - * self.config.random_scale - ) - self.xys_grad_norm = None - self.max_2Dsize = None - distances, _ = self.k_nearest_sklearn(means.data, 3) - distances = torch.from_numpy(distances) - # find the average of the three nearest neighbors for each point and use that as the scale - avg_dist = distances.mean(dim=-1, keepdim=True) - scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3))) - num_points = means.shape[0] - quats = torch.nn.Parameter(random_quat_tensor(num_points)) - dim_sh = num_sh_bases(self.config.sh_degree) - - if ( - self.seed_points is not None - and not self.config.random_init - # We can have colors without points. - and self.seed_points[1].shape[0] > 0 - ): - shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda() - if self.config.sh_degree > 0: - shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255) - shs[:, 1:, 3:] = 0.0 - else: - CONSOLE.log("use color only optimization with sigmoid activation") - shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10) - features_dc = torch.nn.Parameter(shs[:, 0, :]) - features_rest = torch.nn.Parameter(shs[:, 1:, :]) - else: - features_dc = torch.nn.Parameter(torch.rand(num_points, 3)) - features_rest = torch.nn.Parameter(torch.zeros((num_points, dim_sh - 1, 3))) - - opacities = torch.nn.Parameter(torch.logit(0.1 * torch.ones(num_points, 1))) - self.gauss_params = torch.nn.ParameterDict( - { - "means": means, - "scales": scales, - "quats": quats, - "features_dc": features_dc, - "features_rest": features_rest, - "opacities": opacities, - } - ) - - self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup( - num_cameras=self.num_train_data, device="cpu" - ) - - # metrics - from torchmetrics.image import PeakSignalNoiseRatio - from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity - - self.psnr = PeakSignalNoiseRatio(data_range=1.0) - self.ssim = SSIM(data_range=1.0, size_average=True, channel=3) - self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True) - self.step = 0 - - self.crop_box: Optional[OrientedBox] = None - if self.config.background_color == "random": - self.background_color = torch.tensor( - [0.1490, 0.1647, 0.2157] - ) # This color is the same as the default background color in Viser. This would only affect the background color when rendering. - else: - self.background_color = get_color(self.config.background_color) - - if self.config.use_bilateral_grid: - self.bil_grids = BilateralGrid(num=self.num_train_data) - - @property - def colors(self): - if self.config.sh_degree > 0: - return SH2RGB(self.features_dc) - else: - return torch.sigmoid(self.features_dc) - - @property - def shs_0(self): - return self.features_dc - - @property - def shs_rest(self): - return self.features_rest - - @property - def num_points(self): - return self.means.shape[0] - - @property - def means(self): - return self.gauss_params["means"] - - @property - def scales(self): - return self.gauss_params["scales"] - - @property - def quats(self): - return self.gauss_params["quats"] - - @property - def features_dc(self): - return self.gauss_params["features_dc"] - - @property - def features_rest(self): - return self.gauss_params["features_rest"] - - @property - def opacities(self): - return self.gauss_params["opacities"] - - def load_state_dict(self, dict, **kwargs): # type: ignore - # resize the parameters to match the new number of points - self.step = 30000 - if "means" in dict: - # For backwards compatibility, we remap the names of parameters from - # means->gauss_params.means since old checkpoints have that format - for p in [ - "means", - "scales", - "quats", - "features_dc", - "features_rest", - "opacities", - ]: - dict[f"gauss_params.{p}"] = dict[p] - newp = dict["gauss_params.means"].shape[0] - for name, param in self.gauss_params.items(): - old_shape = param.shape - new_shape = (newp,) + old_shape[1:] - self.gauss_params[name] = torch.nn.Parameter( - torch.zeros(new_shape, device=self.device) - ) - super().load_state_dict(dict, **kwargs) - - def k_nearest_sklearn(self, x: torch.Tensor, k: int): - """ - Find k-nearest neighbors using sklearn's NearestNeighbors. - x: The data tensor of shape [num_samples, num_features] - k: The number of neighbors to retrieve - """ - # Convert tensor to numpy array - x_np = x.cpu().numpy() - - # Build the nearest neighbors model - from sklearn.neighbors import NearestNeighbors - - nn_model = NearestNeighbors( - n_neighbors=k + 1, algorithm="auto", metric="euclidean" - ).fit(x_np) - - # Find the k-nearest neighbors - distances, indices = nn_model.kneighbors(x_np) - - # Exclude the point itself from the result and return - return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) - - def remove_from_optim(self, optimizer, deleted_mask, new_params): - """removes the deleted_mask from the optimizer provided""" - assert len(new_params) == 1 - # assert isinstance(optimizer, torch.optim.Adam), "Only works with Adam" - - param = optimizer.param_groups[0]["params"][0] - param_state = optimizer.state[param] - del optimizer.state[param] - - # Modify the state directly without deleting and reassigning. - if "exp_avg" in param_state: - param_state["exp_avg"] = param_state["exp_avg"][~deleted_mask] - param_state["exp_avg_sq"] = param_state["exp_avg_sq"][~deleted_mask] - - # Update the parameter in the optimizer's param group. - del optimizer.param_groups[0]["params"][0] - del optimizer.param_groups[0]["params"] - optimizer.param_groups[0]["params"] = new_params - optimizer.state[new_params[0]] = param_state - - def remove_from_all_optim(self, optimizers, deleted_mask): - param_groups = self.get_gaussian_param_groups() - for group, param in param_groups.items(): - self.remove_from_optim(optimizers.optimizers[group], deleted_mask, param) - torch.cuda.empty_cache() - - def dup_in_optim(self, optimizer, dup_mask, new_params, n=2): - """adds the parameters to the optimizer""" - param = optimizer.param_groups[0]["params"][0] - param_state = optimizer.state[param] - if "exp_avg" in param_state: - repeat_dims = (n,) + tuple( - 1 for _ in range(param_state["exp_avg"].dim() - 1) - ) - param_state["exp_avg"] = torch.cat( - [ - param_state["exp_avg"], - torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat( - *repeat_dims - ), - ], - dim=0, - ) - param_state["exp_avg_sq"] = torch.cat( - [ - param_state["exp_avg_sq"], - torch.zeros_like( - param_state["exp_avg_sq"][dup_mask.squeeze()] - ).repeat(*repeat_dims), - ], - dim=0, - ) - del optimizer.state[param] - optimizer.state[new_params[0]] = param_state - optimizer.param_groups[0]["params"] = new_params - del param - - def dup_in_all_optim(self, optimizers, dup_mask, n): - param_groups = self.get_gaussian_param_groups() - for group, param in param_groups.items(): - self.dup_in_optim(optimizers.optimizers[group], dup_mask, param, n) - - def after_train(self, step: int): - assert step == self.step - # to save some training time, we no longer need to update those stats post refinement - if self.step >= self.config.stop_split_at: - return - with torch.no_grad(): - # keep track of a moving average of grad norms - visible_mask = (self.radii > 0).flatten() - grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore - # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}") - if self.xys_grad_norm is None: - self.xys_grad_norm = torch.zeros( - self.num_points, device=self.device, dtype=torch.float32 - ) - self.vis_counts = torch.ones( - self.num_points, device=self.device, dtype=torch.float32 - ) - assert self.vis_counts is not None - self.vis_counts[visible_mask] += 1 - self.xys_grad_norm[visible_mask] += grads - # update the max screen size, as a ratio of number of pixels - if self.max_2Dsize is None: - self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) - newradii = self.radii.detach()[visible_mask] - self.max_2Dsize[visible_mask] = torch.maximum( - self.max_2Dsize[visible_mask], - newradii / float(max(self.last_size[0], self.last_size[1])), - ) - - def set_crop(self, crop_box: Optional[OrientedBox]): - self.crop_box = crop_box - - def set_background(self, background_color: torch.Tensor): - assert background_color.shape == (3,) - self.background_color = background_color - - def refinement_after(self, optimizers: Optimizers, step): - assert step == self.step - if self.step <= self.config.warmup_length: - return - with torch.no_grad(): - # Offset all the opacity reset logic by refine_every so that we don't - # save checkpoints right when the opacity is reset (saves every 2k) - # then cull - # only split/cull if we've seen every image since opacity reset - reset_interval = self.config.reset_alpha_every * self.config.refine_every - do_densification = ( - self.step < self.config.stop_split_at - and self.step % reset_interval - > self.num_train_data + self.config.refine_every - ) - if do_densification: - # then we densify - assert ( - self.xys_grad_norm is not None - and self.vis_counts is not None - and self.max_2Dsize is not None - ) - avg_grad_norm = ( - (self.xys_grad_norm / self.vis_counts) - * 0.5 - * max(self.last_size[0], self.last_size[1]) - ) - high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze() - splits = ( - self.scales.exp().max(dim=-1).values - > self.config.densify_size_thresh - ).squeeze() - if self.step < self.config.stop_screen_size_at: - splits |= ( - self.max_2Dsize > self.config.split_screen_size - ).squeeze() - splits &= high_grads - nsamps = self.config.n_split_samples - split_params = self.split_gaussians(splits, nsamps) - - dups = ( - self.scales.exp().max(dim=-1).values - <= self.config.densify_size_thresh - ).squeeze() - dups &= high_grads - dup_params = self.dup_gaussians(dups) - for name, param in self.gauss_params.items(): - self.gauss_params[name] = torch.nn.Parameter( - torch.cat( - [param.detach(), split_params[name], dup_params[name]], - dim=0, - ) - ) - # append zeros to the max_2Dsize tensor - self.max_2Dsize = torch.cat( - [ - self.max_2Dsize, - torch.zeros_like(split_params["scales"][:, 0]), - torch.zeros_like(dup_params["scales"][:, 0]), - ], - dim=0, - ) - - split_idcs = torch.where(splits)[0] - self.dup_in_all_optim(optimizers, split_idcs, nsamps) - - dup_idcs = torch.where(dups)[0] - self.dup_in_all_optim(optimizers, dup_idcs, 1) - - # After a guassian is split into two new gaussians, the original one should also be pruned. - splits_mask = torch.cat( - ( - splits, - torch.zeros( - nsamps * splits.sum() + dups.sum(), - device=self.device, - dtype=torch.bool, - ), - ) - ) - - deleted_mask = self.cull_gaussians(splits_mask) - elif ( - self.step >= self.config.stop_split_at - and self.config.continue_cull_post_densification - ): - deleted_mask = self.cull_gaussians() - else: - # if we donot allow culling post refinement, no more gaussians will be pruned. - deleted_mask = None - - if deleted_mask is not None: - self.remove_from_all_optim(optimizers, deleted_mask) - - if ( - self.step < self.config.stop_split_at - and self.step % reset_interval == self.config.refine_every - ): - # Reset value is set to be twice of the cull_alpha_thresh - reset_value = self.config.cull_alpha_thresh * 2.0 - self.opacities.data = torch.clamp( - self.opacities.data, - max=torch.logit( - torch.tensor(reset_value, device=self.device) - ).item(), - ) - # reset the exp of optimizer - optim = optimizers.optimizers["opacities"] - param = optim.param_groups[0]["params"][0] - param_state = optim.state[param] - param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"]) - param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"]) - - self.xys_grad_norm = None - self.vis_counts = None - self.max_2Dsize = None - - def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None): - """ - This function deletes gaussians with under a certain opacity threshold - extra_cull_mask: a mask indicates extra gaussians to cull besides existing culling criterion - """ - n_bef = self.num_points - # cull transparent ones - culls = ( - torch.sigmoid(self.opacities) < self.config.cull_alpha_thresh - ).squeeze() - below_alpha_count = torch.sum(culls).item() - toobigs_count = 0 - if extra_cull_mask is not None: - culls = culls | extra_cull_mask - if self.step > self.config.refine_every * self.config.reset_alpha_every: - # cull huge ones - toobigs = ( - torch.exp(self.scales).max(dim=-1).values - > self.config.cull_scale_thresh - ).squeeze() - if self.step < self.config.stop_screen_size_at: - # cull big screen space - if self.max_2Dsize is not None: - toobigs = ( - toobigs - | (self.max_2Dsize > self.config.cull_screen_size).squeeze() - ) - culls = culls | toobigs - toobigs_count = torch.sum(toobigs).item() - for name, param in self.gauss_params.items(): - self.gauss_params[name] = torch.nn.Parameter(param[~culls]) - - CONSOLE.log( - f"Culled {n_bef - self.num_points} gaussians " - f"({below_alpha_count} below alpha thresh, {toobigs_count} too bigs, {self.num_points} remaining)" - ) - - return culls - - def split_gaussians(self, split_mask, samps): - """ - This function splits gaussians that are too large - """ - n_splits = split_mask.sum().item() - CONSOLE.log( - f"Splitting {split_mask.sum().item()/self.num_points} gaussians: {n_splits}/{self.num_points}" - ) - centered_samples = torch.randn( - (samps * n_splits, 3), device=self.device - ) # Nx3 of axis-aligned scales - scaled_samples = ( - torch.exp(self.scales[split_mask].repeat(samps, 1)) * centered_samples - ) # how these scales are rotated - quats = self.quats[split_mask] / self.quats[split_mask].norm( - dim=-1, keepdim=True - ) # normalize them first - rots = quat_to_rotmat(quats.repeat(samps, 1)) # how these scales are rotated - rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze() - new_means = rotated_samples + self.means[split_mask].repeat(samps, 1) - # step 2, sample new colors - new_features_dc = self.features_dc[split_mask].repeat(samps, 1) - new_features_rest = self.features_rest[split_mask].repeat(samps, 1, 1) - # step 3, sample new opacities - new_opacities = self.opacities[split_mask].repeat(samps, 1) - # step 4, sample new scales - size_fac = 1.6 - new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat( - samps, 1 - ) - self.scales[split_mask] = torch.log( - torch.exp(self.scales[split_mask]) / size_fac - ) - # step 5, sample new quats - new_quats = self.quats[split_mask].repeat(samps, 1) - out = { - "means": new_means, - "features_dc": new_features_dc, - "features_rest": new_features_rest, - "opacities": new_opacities, - "scales": new_scales, - "quats": new_quats, - } - for name, param in self.gauss_params.items(): - if name not in out: - out[name] = param[split_mask].repeat(samps, 1) - return out - - def dup_gaussians(self, dup_mask): - """ - This function duplicates gaussians that are too small - """ - n_dups = dup_mask.sum().item() - CONSOLE.log( - f"Duplicating {dup_mask.sum().item()/self.num_points} gaussians: {n_dups}/{self.num_points}" - ) - new_dups = {} - for name, param in self.gauss_params.items(): - new_dups[name] = param[dup_mask] - return new_dups - - def get_training_callbacks( - self, training_callback_attributes: TrainingCallbackAttributes - ) -> List[TrainingCallback]: - cbs = [] - cbs.append( - TrainingCallback( - [TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], self.step_cb - ) - ) - # The order of these matters - cbs.append( - TrainingCallback( - [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], - self.after_train, - ) - ) - cbs.append( - TrainingCallback( - [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], - self.refinement_after, - update_every_num_iters=self.config.refine_every, - args=[training_callback_attributes.optimizers], - ) - ) - return cbs - - def step_cb(self, step): - self.step = step - - def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]: - # Here we explicitly use the means, scales as parameters so that the user can override this function and - # specify more if they want to add more optimizable params to gaussians. - return { - name: [self.gauss_params[name]] - for name in [ - "means", - "scales", - "quats", - "features_dc", - "features_rest", - "opacities", - ] - } - - def get_param_groups(self) -> Dict[str, List[Parameter]]: - """Obtain the parameter groups for the optimizers - - Returns: - Mapping of different parameter groups - """ - gps = self.get_gaussian_param_groups() - gps["bil_grids"] = list(self.bil_grids.parameters()) - self.camera_optimizer.get_param_groups(param_groups=gps) - return gps - - def _get_downscale_factor(self): - if self.training: - return 2 ** max( - ( - self.config.num_downscales - - self.step // self.config.resolution_schedule - ), - 0, - ) - else: - return 1 - - def _downscale_if_required(self, image): - d = self._get_downscale_factor() - if d > 1: - return resize_image(image, d) - return image - - @staticmethod - def get_empty_outputs( - width: int, height: int, background: torch.Tensor - ) -> Dict[str, Union[torch.Tensor, List]]: - rgb = background.repeat(height, width, 1) - depth = background.new_ones(*rgb.shape[:2], 1) * 10 - accumulation = background.new_zeros(*rgb.shape[:2], 1) - return { - "rgb": rgb, - "depth": depth, - "accumulation": accumulation, - "background": background, - } - - def _get_background_color(self): - if self.config.background_color == "random": - if self.training: - background = torch.rand(3, device=self.device) - else: - background = self.background_color.to(self.device) - elif self.config.background_color == "white": - background = torch.ones(3, device=self.device) - elif self.config.background_color == "black": - background = torch.zeros(3, device=self.device) - else: - raise ValueError(f"Unknown background color {self.config.background_color}") - return background - - def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: - """Takes in a Ray Bundle and returns a dictionary of outputs. - - Args: - ray_bundle: Input bundle of rays. This raybundle should have all the - needed information to compute the outputs. - - Returns: - Outputs of model. (ie. rendered colors) - """ - if not isinstance(camera, Cameras): - print("Called get_outputs with not a camera") - return {} - - if self.training: - assert camera.shape[0] == 1, "Only one camera at a time" - optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera) - else: - optimized_camera_to_world = camera.camera_to_worlds - - # cropping - if self.crop_box is not None and not self.training: - crop_ids = self.crop_box.within(self.means).squeeze() - if crop_ids.sum() == 0: - return self.get_empty_outputs( - int(camera.width.item()), - int(camera.height.item()), - self.background_color, - ) - else: - crop_ids = None - - if crop_ids is not None: - opacities_crop = self.opacities[crop_ids] - means_crop = self.means[crop_ids] - features_dc_crop = self.features_dc[crop_ids] - features_rest_crop = self.features_rest[crop_ids] - scales_crop = self.scales[crop_ids] - quats_crop = self.quats[crop_ids] - else: - opacities_crop = self.opacities - means_crop = self.means - features_dc_crop = self.features_dc - features_rest_crop = self.features_rest - scales_crop = self.scales - quats_crop = self.quats - - colors_crop = torch.cat( - (features_dc_crop[:, None, :], features_rest_crop), dim=1 - ) - - BLOCK_WIDTH = ( - 16 # this controls the tile size of rasterization, 16 is a good default - ) - camera_scale_fac = self._get_downscale_factor() - camera.rescale_output_resolution(1 / camera_scale_fac) - viewmat = get_viewmat(optimized_camera_to_world) - K = camera.get_intrinsics_matrices().cuda() - W, H = int(camera.width.item()), int(camera.height.item()) - self.last_size = (H, W) - camera.rescale_output_resolution(camera_scale_fac) # type: ignore - - # apply the compensation of screen space blurring to gaussians - if self.config.rasterize_mode not in ["antialiased", "classic"]: - raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode) - - if self.config.output_depth_during_training or not self.training: - render_mode = "RGB+ED" - else: - render_mode = "RGB" - - if self.config.sh_degree > 0: - sh_degree_to_use = min( - self.step // self.config.sh_degree_interval, self.config.sh_degree - ) - else: - colors_crop = torch.sigmoid(colors_crop) - sh_degree_to_use = None - - render, alpha, info = rasterization( - means=means_crop, - quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True), - scales=torch.exp(scales_crop), - opacities=torch.sigmoid(opacities_crop).squeeze(-1), - colors=colors_crop, - viewmats=viewmat, # [1, 4, 4] - Ks=K, # [1, 3, 3] - width=W, - height=H, - tile_size=BLOCK_WIDTH, - packed=False, - near_plane=0.01, - far_plane=1e10, - render_mode=render_mode, - sh_degree=sh_degree_to_use, - sparse_grad=False, - absgrad=True, - rasterize_mode=self.config.rasterize_mode, - # set some threshold to disregrad small gaussians for faster rendering. - # radius_clip=3.0, - ) - if self.training and info["means2d"].requires_grad: - info["means2d"].retain_grad() - self.xys = info["means2d"] # [1, N, 2] - self.radii = info["radii"][0] # [N] - alpha = alpha[:, ...] - - background = self._get_background_color() - rgb = render[:, ..., :3] + (1 - alpha) * background - rgb = torch.clamp(rgb, 0.0, 1.0) - - # apply bilateral grid - if self.config.use_bilateral_grid and self.training: - if camera.metadata is not None and "cam_idx" in camera.metadata: - cam_idx = camera.metadata["cam_idx"] - if cam_idx != 0: - # make xy grid - grid_y, grid_x = torch.meshgrid( - torch.linspace(0, 1.0, H), - torch.linspace(0, 1.0, W), - indexing="ij", - ) - grid_xy = ( - torch.stack([grid_x, grid_y], dim=-1) - .unsqueeze(0) - .to(self.device) - ) - - # prepare grid idx - # grid_idx = ( - # torch.ones((H, W), dtype=torch.long, device=self.device) * cam_idx - # ) - # grid_idx = grid_idx.unsqueeze(-1) # [H, W, 1] - # process rgb - out = slice( - bil_grids=self.bil_grids, - rgb=rgb, - xy=grid_xy, - grid_idx=torch.tensor( - cam_idx, device=self.device, dtype=torch.long - ), - ) - rgb = out["rgb"] - - if render_mode == "RGB+ED": - depth_im = render[:, ..., 3:4] - depth_im = torch.where( - alpha > 0, depth_im, depth_im.detach().max() - ).squeeze(0) - else: - depth_im = None - - if background.shape[0] == 3 and not self.training: - background = background.expand(H, W, 3) - - return { - "rgb": rgb.squeeze(0), # type: ignore - "depth": depth_im, # type: ignore - "accumulation": alpha.squeeze(0), # type: ignore - "background": background, # type: ignore - } # type: ignore - - def get_gt_img(self, image: torch.Tensor): - """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose - - Args: - image: tensor.Tensor in type uint8 or float32 - """ - if image.dtype == torch.uint8: - image = image.float() / 255.0 - gt_img = self._downscale_if_required(image) - return gt_img.to(self.device) - - def composite_with_background(self, image, background) -> torch.Tensor: - """Composite the ground truth image with a background color when it has an alpha channel. - - Args: - image: the image to composite - background: the background color - """ - if image.shape[2] == 4: - alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3)) - return alpha * image[..., :3] + (1 - alpha) * background - else: - return image - - def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: - """Compute and returns metrics. - - Args: - outputs: the output to compute loss dict to - batch: ground truth batch corresponding to outputs - """ - gt_rgb = self.composite_with_background( - self.get_gt_img(batch["image"]), outputs["background"] - ) - metrics_dict = {} - predicted_rgb = outputs["rgb"] - metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) - - metrics_dict["gaussian_count"] = self.num_points - - self.camera_optimizer.get_metrics_dict(metrics_dict) - return metrics_dict - - def get_loss_dict( - self, outputs, batch, metrics_dict=None - ) -> Dict[str, torch.Tensor]: - """Computes and returns the losses dict. - - Args: - outputs: the output to compute loss dict to - batch: ground truth batch corresponding to outputs - metrics_dict: dictionary of metrics, some of which we can use for loss - """ - gt_img = self.composite_with_background( - self.get_gt_img(batch["image"]), outputs["background"] - ) - pred_img = outputs["rgb"] - - # Set masked part of both ground-truth and rendered image to black. - # This is a little bit sketchy for the SSIM loss. - if "mask" in batch: - # batch["mask"] : [H, W, 1] - mask = self._downscale_if_required(batch["mask"]) - mask = mask.to(self.device) - assert mask.shape[:2] == gt_img.shape[:2] == pred_img.shape[:2] - gt_img = gt_img * mask - pred_img = pred_img * mask - - Ll1 = torch.abs(gt_img - pred_img).mean() - simloss = 1 - self.ssim( - gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...] - ) - if self.config.use_scale_regularization and self.step % 10 == 0: - scale_exp = torch.exp(self.scales) - scale_reg = ( - torch.maximum( - scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), - torch.tensor(self.config.max_gauss_ratio), - ) - - self.config.max_gauss_ratio - ) - scale_reg = 0.1 * scale_reg.mean() - else: - scale_reg = torch.tensor(0.0).to(self.device) - - loss_dict = { - "main_loss": (1 - self.config.ssim_lambda) * Ll1 - + self.config.ssim_lambda * simloss, - "scale_reg": scale_reg, - } - - if self.training: - # Add loss from camera optimizer - self.camera_optimizer.get_loss_dict(loss_dict) - loss_dict["tv_loss"] = 10 * total_variation_loss(self.bil_grids.grids) - - return loss_dict - - @torch.no_grad() - def get_outputs_for_camera( - self, camera: Cameras, obb_box: Optional[OrientedBox] = None - ) -> Dict[str, torch.Tensor]: - """Takes in a camera, generates the raybundle, and computes the output of the model. - Overridden for a camera-based gaussian model. - - Args: - camera: generates raybundle - """ - assert camera is not None, "must provide camera to gaussian model" - self.set_crop(obb_box) - outs = self.get_outputs(camera.to(self.device)) - return outs # type: ignore - - def get_image_metrics_and_images( - self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] - ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: - """Writes the test image outputs. - - Args: - image_idx: Index of the image. - step: Current step. - batch: Batch of data. - outputs: Outputs of the model. - - Returns: - A dictionary of metrics. - """ - gt_rgb = self.composite_with_background( - self.get_gt_img(batch["image"]), outputs["background"] - ) - predicted_rgb = outputs["rgb"] - - combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) - - # Switch images from [H, W, C] to [1, C, H, W] for metrics computations - gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] - predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] - - psnr = self.psnr(gt_rgb, predicted_rgb) - ssim = self.ssim(gt_rgb, predicted_rgb) - lpips = self.lpips(gt_rgb, predicted_rgb) - - # all of these metrics will be logged as scalars - metrics_dict = {"psnr": float(psnr.item()), "ssim": float(ssim)} # type: ignore - metrics_dict["lpips"] = float(lpips) - - images_dict = {"img": combined_rgb} - - return metrics_dict, images_dict From 4147a6aba29ab6692e1eed658fcdb612b1056a20 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 16:05:03 -0700 Subject: [PATCH 25/78] this commit was used to generate the time metrics and profiling metrics --- .../data/datamanagers/base_datamanager.py | 43 +++++++++---------- nerfstudio/data/datasets/base_dataset.py | 23 +++++----- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 21b52b25d1..9feaa0f3b6 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -342,7 +342,7 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = 4 # increasing prefetch_factor was not beneficial """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - dataloader_num_workers: int = 4 + dataloader_num_workers: int = 1 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" use_ray_train_dataloader: bool = True @@ -462,25 +462,22 @@ def _get_batch_list(self, indices=None): else 4 * int(self.num_image_load_threads) ) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) - num_threads = max(num_threads, 1) - # print('num_threads', num_threads) # prints 16 - - + num_threads = max(num_threads, 1) # print('num_threads', num_threads) # prints 16 # NB: this is I/O heavy because we are going to disk and reading an image filename # hence multi-threaded inside the worker - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for idx in indices: - res = executor.submit(self.input_dataset.__getitem__, idx) - results.append(res) - - # for res in tqdm(results, desc='_get_batch_list'): - results = tqdm(results) # does not effect times, tested many times - for res in results: - batch_list.append(res.result()) + # with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # for idx in indices: + # res = executor.submit(self.input_dataset.__getitem__, idx) + # results.append(res) + + # # for res in tqdm(results, desc='_get_batch_list'): + # results = tqdm(results) # does not effect times, tested many times + # for res in results: + # batch_list.append(res.result()) - # for idx in tqdm(indices): # this is slower compared to using threads - # batch_list.append(self.input_dataset.__getitem__(idx)) + for idx in tqdm(indices): # this is slower compared to using threads + batch_list.append(self.input_dataset.__getitem__(idx)) return batch_list def _get_collated_batch(self, indices=None): @@ -537,13 +534,13 @@ def __iter__(self): else: # get a total of 'num_images_to_sample_from' image indices image_indices = worker_indices[:self.num_images_to_sample_from] # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. - # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: - # with record_function("process_images"): - collated_batch = self._get_collated_batch(image_indices) - # with open('_get_batch_list_profile.txt', 'w') as f: - # f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) - # f.write("\n\nMemory Usage:\n") - # f.write(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20)) + with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_stack=True,) as prof: + with record_function("process_images"): + collated_batch = self._get_collated_batch(image_indices) + + with open('_get_batch_list_profile.txt', 'w') as f: + f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + i += 1 """ Here, the variable 'batch' refers to the output of our pixel sampler. diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index e535bffb38..56c3827fde 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -71,7 +71,7 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = def __len__(self): return len(self._dataparser_outputs.image_filenames) - def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: + def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.float32]: """Returns the image of shape (H, W, 3 or 4). Args: @@ -86,10 +86,12 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "float32" + with record_function("pil_to_numpy()"): + image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 + assert image.dtype == np.uint8 assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct." return image @@ -99,7 +101,8 @@ def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image Args: image_idx: The image index in the dataset. """ - image = torch.from_numpy(self.get_numpy_image(image_idx)) + with record_function("divide by 255.0 and convert to float32"): + image = torch.from_numpy(self.get_numpy_image(image_idx) / np.float32(255.0)) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 @@ -131,12 +134,13 @@ def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "fl image_idx: The image index in the dataset. image_type: the type of images returned """ - if image_type == "float32": - image = self.get_image_float32(image_idx) - elif image_type == "uint8": - image = self.get_image_uint8(image_idx) - else: - raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") + with record_function("divide by 255.0 and convert to float32"): + if image_type == "float32": + image = self.get_image_float32(image_idx) + elif image_type == "uint8": + image = self.get_image_uint8(image_idx) + else: + raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: @@ -177,4 +181,3 @@ def image_filenames(self) -> List[Path]: """ return self._dataparser_outputs.image_filenames - \ No newline at end of file From 5a55b7a0c689de31218a9fd0680e541630358309 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 26 Jul 2024 16:30:49 -0700 Subject: [PATCH 26/78] REAL commit used to run tests --- nerfstudio/data/datasets/base_dataset.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 56c3827fde..d87669a1c0 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -71,7 +71,7 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = def __len__(self): return len(self._dataparser_outputs.image_filenames) - def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.float32]: + def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: """Returns the image of shape (H, W, 3 or 4). Args: @@ -86,8 +86,7 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.float32]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - with record_function("pil_to_numpy()"): - image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" + image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 @@ -101,8 +100,12 @@ def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image Args: image_idx: The image index in the dataset. """ - with record_function("divide by 255.0 and convert to float32"): - image = torch.from_numpy(self.get_numpy_image(image_idx) / np.float32(255.0)) + with record_function("pil_to_numpy()"): + image = self.get_numpy_image(image_idx) + with record_function("divide by 255.0 + convert to float32"): + image = image / np.float32(255) + with record_function("torch.from_numpy()"): + image = torch.from_numpy(image) if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 @@ -134,13 +137,12 @@ def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "fl image_idx: The image index in the dataset. image_type: the type of images returned """ - with record_function("divide by 255.0 and convert to float32"): - if image_type == "float32": - image = self.get_image_float32(image_idx) - elif image_type == "uint8": - image = self.get_image_uint8(image_idx) - else: - raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") + if image_type == "float32": + image = self.get_image_float32(image_idx) + elif image_type == "uint8": + image = self.get_image_uint8(image_idx) + else: + raise NotImplementedError(f"image_type (={image_type}) getter was not implemented, use uint8 or float32") data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: From 78f02e6d8aaa8147eb06e8555e5986a56de91085 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 15 Aug 2024 06:27:21 -0700 Subject: [PATCH 27/78] testing with nerfacto-big --- nerfstudio/configs/method_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index c15233feae..d1deb594d0 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -126,7 +126,7 @@ max_num_iterations=100000, mixed_precision=True, pipeline=VanillaPipelineConfig( - datamanager=ParallelDataManagerConfig( + datamanager=VanillaDataManagerConfig( dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=8192, eval_num_rays_per_batch=4096, From 19bc4b5b5f68c1126b670dd6dc707c24d7503cee Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 15 Aug 2024 06:36:36 -0700 Subject: [PATCH 28/78] generated RayBundle collate and converting images from uint8s to float32 on GPU tests --- .../data/datamanagers/base_datamanager.py | 88 +++++++++++++++---- .../datamanagers/full_images_datamanager.py | 22 +++-- nerfstudio/data/datasets/base_dataset.py | 9 +- nerfstudio/data/utils/data_utils.py | 5 +- nerfstudio/data/utils/nerfstudio_collate.py | 14 +-- nerfstudio/models/nerfacto.py | 19 +++- nerfstudio/pipelines/base_pipeline.py | 2 + 7 files changed, 118 insertions(+), 41 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 9feaa0f3b6..d124e4a502 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -90,7 +90,6 @@ def variable_res_collate(batch: List[Dict]) -> Dict: # now that iteration is complete, the image data items can be removed from the batch for key in topop: del data[key] - new_batch = nerfstudio_collate(batch) new_batch["image"] = images new_batch.update(imgdata_lists) @@ -98,6 +97,31 @@ def variable_res_collate(batch: List[Dict]) -> Dict: return new_batch +def ray_collate(batch: List[RayBundle]): + # start = time.time() + ray_bundle_list, batch_list = list(zip(*batch)) + combined_metadata = {} + if "fisheye_crop_radius" in ray_bundle_list[0].metadata: + combined_metadata["fisheye_crop_radius"] = ray_bundle_list[0].metadata["fisheye_crop_radius"] + if "directions_norm" in ray_bundle_list[0].metadata: + combined_metadata["directions_norm"] = torch.cat([ray_bundle_i.metadata["directions_norm"] for ray_bundle_i in ray_bundle_list], dim=0) + + concatenated_ray_bundle = RayBundle( + origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0), + directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0), + pixel_area=torch.cat([ray_bundle_i.pixel_area for ray_bundle_i in ray_bundle_list], dim=0), + camera_indices=torch.cat([ray_bundle_i.camera_indices for ray_bundle_i in ray_bundle_list], dim=0), + metadata=combined_metadata, + ) + concatenated_batch = { + "image" : torch.cat([batch_i["image"] for batch_i in batch_list], dim=0), + "indices": torch.cat([batch_i["indices"] for batch_i in batch_list], dim=0), + } + # end = time.time() + # print((end - start) * 1000) + return [[concatenated_ray_bundle, concatenated_batch]] + + @dataclass class DataManagerConfig(InstantiateConfig): """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers; @@ -339,16 +363,16 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = 4 # increasing prefetch_factor was not beneficial + prefetch_factor: int = 5 # prefetch_factor of 16 does well, but any that is equal train_num_times_to_repeat_images is good """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - dataloader_num_workers: int = 1 + dataloader_num_workers: int = 4 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" use_ray_train_dataloader: bool = True """Allows parallelization of the dataloading process with multiple workers.""" cache_binaries: bool = True - """If True, caches the images as binary strings to RAM""" + """When enabled, cache raw image files to RAM""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -406,10 +430,13 @@ def __init__( self.num_images_to_sample_from = num_images_to_sample_from self.num_times_to_repeat_images = num_times_to_repeat_images self.device = device + # self.collate_fn = variable_res_collate # variable_res_collate avoids collating images, which is much faster than `nerfstudio_collate` self.collate_fn = collate_fn + print("collate_fn", self.collate_fn) + print("self.device", self.device) self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - + # print("self.exclude_batch_keys_from_device", self.exclude_batch_keys_from_device) # usually prints ['image'] self.datamanager_config = datamanager_config self.pixel_sampler: PixelSampler = None self.ray_generator: RayGenerator = None @@ -476,7 +503,7 @@ def _get_batch_list(self, indices=None): # for res in results: # batch_list.append(res.result()) - for idx in tqdm(indices): # this is slower compared to using threads + for idx in tqdm(indices): # this is slower compared to using threads, but using this allows us to profile __getitem__ batch_list.append(self.input_dataset.__getitem__(idx)) return batch_list @@ -491,13 +518,20 @@ def _get_collated_batch(self, indices=None): with record_function("_get_batch_list"): batch_list = self._get_batch_list(indices=indices) # # print(type(batch_list[0])) # prints - # print(self.collate_fn) # prints nerfstudio_collate - with record_function("nerfstudio_collate"): + # print(self.collate_fn) # prints nerfstudio_collate on mainRGB, but prints variable_res_collate if all3cameras + with record_function("collate_function"): collated_batch = self.collate_fn(batch_list) with record_function("sending to GPU"): collated_batch = get_dict_to_torch( collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device ) + # with record_function("converting to float32 + divide255 on GPU"): + # collated_batch["image"] = convert_uint8_to_float32(collated_batch["image"]) + # batch_list = self._get_batch_list(indices=indices) + # collated_batch = self.collate_fn(batch_list) + # collated_batch = get_dict_to_torch( + # collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + # ) return collated_batch def __iter__(self): @@ -534,13 +568,13 @@ def __iter__(self): else: # get a total of 'num_images_to_sample_from' image indices image_indices = worker_indices[:self.num_images_to_sample_from] # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. - with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_stack=True,) as prof: - with record_function("process_images"): - collated_batch = self._get_collated_batch(image_indices) - - with open('_get_batch_list_profile.txt', 'w') as f: - f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) - + + collated_batch = self._get_collated_batch(image_indices) + # with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_stack=True,) as prof: + # with record_function("process_images"): + # collated_batch = self._get_collated_batch(image_indices) + # with open('_get_batch_list_profile.txt', 'w') as f: + # f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) i += 1 """ Here, the variable 'batch' refers to the output of our pixel sampler. @@ -550,9 +584,11 @@ def __iter__(self): What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) """ - batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + # the returned batch also somehow moves the images from the CPU to the GPU + # collated_batch["image"].get_device() will return ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) + ray_bundle = self.ray_generator(ray_indices) # the ray_bundle is on the GPU, but yield ray_bundle, batch @@ -618,7 +654,7 @@ def __init__( cameras = self.train_dataparser_outputs.cameras if len(cameras) > 1: for i in range(1, len(cameras)): - if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height: + if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height or True: # or True: # ADJUST COLLATE FN HERE CONSOLE.print("Variable resolution, using variable_res_collate") self.config.collate_fn = variable_res_collate break @@ -701,6 +737,7 @@ def setup_train(self): collate_fn=self.config.collate_fn, cache_all_n_shard_per_worker=False, ) + # This one uses identity collate self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, batch_size=1, @@ -712,7 +749,20 @@ def setup_train(self): collate_fn=identity, pin_memory_device=self.device, # did not actually speed up my implementation ) - self.iter_train_image_dataloader = None + + # # this one uses ray_collate + # self.ray_dataloader = torch.utils.data.DataLoader( + # self.raybatch_stream, + # batch_size=4, + # num_workers=self.config.dataloader_num_workers, + # prefetch_factor=self.config.prefetch_factor, + # shuffle=False, + # pin_memory=True, + # # Our dataset does batching / collation + # collate_fn=ray_collate, + # pin_memory_device=self.device, # did not actually speed up my implementation + # ) + self.iter_train_raybundles = iter(self.ray_dataloader) else: self.iter_train_raybundles = None diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 0ce84e4b77..52919a7972 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -47,14 +47,6 @@ from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE -class ImageBatchStream(torch.utils.data.IterableDataset): - def __init__( - self, - - ): - return - - # def @dataclass class FullImageDatamanagerConfig(DataManagerConfig): @@ -557,3 +549,17 @@ def _undistort_image( raise NotImplementedError("Only perspective and fisheye cameras are supported") return K, image, mask + + +@dataclass +class FullImageBatchStreamConfig(DataManagerConfig): + _target: Type = field(default_factory=lambda: ImageBatchStream) + +class ImageBatchStream(torch.utils.data.IterableDataset): + def __init__( + self, + + ): + return + + # def \ No newline at end of file diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index d87669a1c0..f38b8294db 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -86,12 +86,13 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" + # image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" + image = np.array(pil_image, copy=False) if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 assert image.dtype == np.uint8 - assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is in correct." + assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is incorrect." return image def get_image_float32(self, image_idx: int) -> Float[Tensor, "image_height image_width num_channels"]: @@ -119,7 +120,7 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w Args: image_idx: The image index in the dataset. """ - image = torch.from_numpy(self.get_numpy_image(image_idx).astype(np.uint8)) + image = torch.from_numpy(self.get_numpy_image(image_idx)) # removed astype(np.uint8) because get_numpy_image returns uint8 if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 @@ -130,7 +131,7 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w image = torch.clamp(image, min=0, max=255).to(torch.uint8) return image - def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict: + def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "uint8") -> Dict: """Returns the ImageDataset data as a dictionary. Args: diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index 0ee43378b8..7be9d6f6dc 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -20,7 +20,7 @@ import numpy as np import torch from PIL import Image - +from torch.profiler import record_function def pil_to_numpy(im: Image) -> np.ndarray: """Converts a PIL Image object to a NumPy array. @@ -32,7 +32,8 @@ def pil_to_numpy(im: Image) -> np.ndarray: numpy.ndarray representing the image data. """ # Load in image completely (PIL defaults to lazy loading) - im.load() + with record_function("im.load()"): + im.load() # Unpack data e = Image._getencoder(im.mode, "raw", im.mode) diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index 8c8a633fb8..f05ea80882 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -25,6 +25,7 @@ import torch.utils.data from nerfstudio.cameras.cameras import Cameras +from torch.profiler import profile, record_function NERFSTUDIO_COLLATE_ERR_MSG_FORMAT = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts, lists or anything in {}; found {}" @@ -94,12 +95,13 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None - if torch.utils.data.get_worker_info() is not None: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum(x.numel() for x in batch) - storage = elem.storage()._new_shared(numel, device=elem.device) - out = elem.new(storage).resize_(len(batch), *list(elem.size())) + with record_function("creating shared memory"): + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out) elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": if elem_type.__name__ in ("ndarray", "memmap"): diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py index bfccfd8797..e52db1dd50 100644 --- a/nerfstudio/models/nerfacto.py +++ b/nerfstudio/models/nerfacto.py @@ -46,7 +46,7 @@ from nerfstudio.model_components.shaders import NormalsShader from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils import colormaps - +from torch.profiler import profile, record_function, ProfilerActivity @dataclass class NerfactoModelConfig(ModelConfig): @@ -362,7 +362,22 @@ def get_metrics_dict(self, outputs, batch): def get_loss_dict(self, outputs, batch, metrics_dict=None): loss_dict = {} - image = batch["image"].to(self.device) + image = batch["image"]#.to(self.device) + # Start profiling + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True) as prof: + with record_function("image_normalization"): + image = image / torch.tensor(255, dtype=torch.float32, device=self.device) + # image = image.float() / 255.0 + + # Write profiler results to a file + profile_path = "profiler_results.txt" + with open(profile_path, "w") as f: + f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + image = image.to(self.device) pred_rgb, gt_rgb = self.renderer_rgb.blend_background_for_loss_computation( pred_image=outputs["rgb"], pred_accumulation=outputs["accumulation"], diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 731f214e77..4fd2de0df4 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -298,6 +298,8 @@ def get_train_loss_dict(self, step: int): step: current iteration step to update sampler if using DDP (distributed) """ ray_bundle, batch = self.datamanager.next_train(step) + # print("ray_bundle.origins.get_device()", ray_bundle.origins.get_device()) # prints 0 (it's on CUDA) + # print("batch['image'].get_device()", batch["image"].get_device()) # prints -1 (it's on CPU) model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 metrics_dict = self.model.get_metrics_dict(model_outputs, batch) loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) From 9245d05da2d7844bde86f97d611024927827d5ef Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 20 Aug 2024 01:27:46 -0700 Subject: [PATCH 29/78] updating nerfacto to support uint8 easily, will need to figure out a way to contain this within the datamanager API --- nerfstudio/models/nerfacto.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py index e52db1dd50..71e60fbb2c 100644 --- a/nerfstudio/models/nerfacto.py +++ b/nerfstudio/models/nerfacto.py @@ -362,20 +362,22 @@ def get_metrics_dict(self, outputs, batch): def get_loss_dict(self, outputs, batch, metrics_dict=None): loss_dict = {} - image = batch["image"]#.to(self.device) + image = batch["image"].to(self.device) # Start profiling - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - profile_memory=True, - with_stack=True) as prof: - with record_function("image_normalization"): - image = image / torch.tensor(255, dtype=torch.float32, device=self.device) - # image = image.float() / 255.0 + if image.dtype == torch.uint8: + image = image / torch.tensor(255, dtype=torch.float32, device=self.device) + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # record_shapes=True, + # profile_memory=True, + # with_stack=True) as prof: + # with record_function("image_normalization"): + # image = image / torch.tensor(255, dtype=torch.float32, device=self.device) + # # image = image.float() / 255.0 - # Write profiler results to a file - profile_path = "profiler_results.txt" - with open(profile_path, "w") as f: - f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + # # Write profiler results to a file + # profile_path = "profiler_results.txt" + # with open(profile_path, "w") as f: + # f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) image = image.to(self.device) pred_rgb, gt_rgb = self.renderer_rgb.blend_background_for_loss_computation( From 3124c1438c22c198813c32626c934e0dbdbd25e1 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 20 Aug 2024 02:02:42 -0700 Subject: [PATCH 30/78] datamanager updates, both splat and nerf --- .../data/datamanagers/base_datamanager.py | 11 +-- .../datamanagers/full_images_datamanager.py | 77 +++++++++++++++++-- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d124e4a502..41643b08bb 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -343,7 +343,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """Number of rays per batch to use per training iteration.""" train_num_images_to_sample_from: int = 100 # usually -1 """Number of images to sample during training iteration.""" - train_num_times_to_repeat_images: int = 5 # usually -1 + train_num_times_to_repeat_images: int = 10 # usually -1 """When not training on all images, number of iterations before picking new images. If -1, never pick new images.""" eval_num_rays_per_batch: int = 1024 @@ -363,7 +363,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = 5 # prefetch_factor of 16 does well, but any that is equal train_num_times_to_repeat_images is good + prefetch_factor: int = train_num_times_to_repeat_images # prefetch_factor of 16 does well, but any that is equal train_num_times_to_repeat_images is good """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" dataloader_num_workers: int = 4 @@ -517,16 +517,13 @@ def _get_collated_batch(self, indices=None): """ with record_function("_get_batch_list"): batch_list = self._get_batch_list(indices=indices) - # # print(type(batch_list[0])) # prints # print(self.collate_fn) # prints nerfstudio_collate on mainRGB, but prints variable_res_collate if all3cameras with record_function("collate_function"): collated_batch = self.collate_fn(batch_list) with record_function("sending to GPU"): collated_batch = get_dict_to_torch( - collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + collated_batch, device=self.device, #exclude=self.exclude_batch_keys_from_device ) - # with record_function("converting to float32 + divide255 on GPU"): - # collated_batch["image"] = convert_uint8_to_float32(collated_batch["image"]) # batch_list = self._get_batch_list(indices=indices) # collated_batch = self.collate_fn(batch_list) # collated_batch = get_dict_to_torch( @@ -588,7 +585,7 @@ def __iter__(self): # the returned batch also somehow moves the images from the CPU to the GPU # collated_batch["image"].get_device() will return ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) # the ray_bundle is on the GPU, but + ray_bundle = self.ray_generator(ray_indices) # the ray_bundle is on the GPU, but batch["image"] is on the CPU yield ray_bundle, batch diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 52919a7972..f9ce5779f3 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -551,15 +551,78 @@ def _undistort_image( return K, image, mask -@dataclass -class FullImageBatchStreamConfig(DataManagerConfig): - _target: Type = field(default_factory=lambda: ImageBatchStream) - +## Let's implement a parallelized splat dataloader! +def undistort_idx(idx: int, dataset: TDataset, config: FullImageDatamanagerConfig) -> Dict[str, torch.Tensor]: + data = dataset.get_data(idx, image_type=config.cache_images_type) + camera = dataset.cameras[idx].reshape(()) + assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( + f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' + f'does not match the camera parameters ({camera.width.item(), camera.height.item()})' + ) + if camera.distortion_params is None or torch.all(camera.distortion_params == 0): + return data + K = camera.get_intrinsics_matrices().numpy() + distortion_params = camera.distortion_params.numpy() + image = data["image"].numpy() + + K, image, mask = _undistort_image(camera, distortion_params, data, image, K) + data["image"] = torch.from_numpy(image) + if mask is not None: + data["mask"] = mask + + dataset.cameras.fx[idx] = float(K[0, 0]) + dataset.cameras.fy[idx] = float(K[1, 1]) + dataset.cameras.cx[idx] = float(K[0, 2]) + dataset.cameras.cy[idx] = float(K[1, 2]) + dataset.cameras.width[idx] = image.shape[1] + dataset.cameras.height[idx] = image.shape[0] + return data + +import math class ImageBatchStream(torch.utils.data.IterableDataset): + """ + A datamanager that outputs full images and cameras instead of raybundles. This makes the + datamanager more lightweight since we don't have to do generate rays. Useful for full-image + training e.g. rasterization pipelines + """ + config: FullImageDatamanagerConfig + dataset: TDataset # can be a train dataset or an eval dataset + def __init__( self, - + datamanager_config: DataManagerConfig, + input_dataset: TDataset, ): - return + self.config = datamanager_config + self.dataset = input_dataset - # def \ No newline at end of file + def __iter__(self): + dataset_indices = list( + range(len(self.input_dataset)) + ) + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: # if we have multiple processes + per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + slice_start = worker_info.id * per_worker + else: # we only have a single process + per_worker = len(self.input_dataset) + slice_start = 0 + worker_indices = dataset_indices[ + slice_start : slice_start + per_worker + ] # the indices of the datapoints in the dataset this worker will load + r = random.Random(self.config.train_cameras_sampling_seed) + idx = 0 + while True: + if idx == per_worker: # if we've iterated through all the worker's partition of images, we need to reshuffle + r.shuffle(worker_indices) + idx = 0 + idx += 1 + if self.config.cache_images != "disk": + # TODO: somehow use the fact we've stored everything in self.cached_train and cached_eval + + yield self.cached_train[worker_indices[idx]], self.dataset.cameras[idx].reshape(()) + else: # when the images are only stored on disk, we need to undistort them and load them into memory + data = undistort_idx(idx, self.dataset, self.config) + + + From afb06129b4da487c7312d13d3ab44bf32f8dab0d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 20 Aug 2024 02:08:40 -0700 Subject: [PATCH 31/78] must use writeable arrays because torch requires them --- nerfstudio/data/datasets/base_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index f38b8294db..c00b8481d6 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -46,7 +46,7 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = True): + def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = False): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -86,8 +86,8 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - # image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" - image = np.array(pil_image, copy=False) + image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" + # image = np.array(pil_image, copy=False) if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 @@ -131,7 +131,7 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w image = torch.clamp(image, min=0, max=255).to(torch.uint8) return image - def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "uint8") -> Dict: + def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "float32") -> Dict: """Returns the ImageDataset data as a dictionary. Args: From 288a740fce2af7ef9006cd49539af5224cbbe178 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Wed, 21 Aug 2024 23:12:29 -0700 Subject: [PATCH 32/78] cleaned up base_dataset, added pickle to utils, more code in full_image, and cleaner desc for base_datamanager --- .../data/datamanagers/base_datamanager.py | 2 +- .../datamanagers/full_images_datamanager.py | 51 ++++++++++++++----- nerfstudio/data/datasets/base_dataset.py | 1 - nerfstudio/data/utils/data_utils.py | 4 ++ 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 41643b08bb..46ca983557 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -370,7 +370,7 @@ class VanillaDataManagerConfig(DataManagerConfig): """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" use_ray_train_dataloader: bool = True - """Allows parallelization of the dataloading process with multiple workers.""" + """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" cache_binaries: bool = True """When enabled, cache raw image files to RAM""" diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index f9ce5779f3..bf4c246a45 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -44,6 +44,7 @@ from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -63,8 +64,8 @@ class FullImageDatamanagerConfig(DataManagerConfig): new images. If -1, never pick new images.""" eval_image_indices: Optional[Tuple[int, ...]] = (0,) """Specifies the image indices to use during eval; if None, uses all.""" - cache_images: Literal["cpu", "gpu"] = "gpu" - """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device.""" + cache_images: Literal["cpu", "gpu", "disk"] = "gpu" + """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device. If "disk", keeps images on disk. """ cache_images_type: Literal["uint8", "float32"] = "float32" """The image type returned from manager, caching images in uint8 saves memory""" max_thread_workers: Optional[int] = None @@ -79,7 +80,14 @@ class FullImageDatamanagerConfig(DataManagerConfig): fps_reset_every: int = 100 """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every samples from the pool of all training cameras without replacement before a new round of sampling starts.""" - + use_image_train_dataloader: bool = cache_images == "disk" + """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" + dataloader_num_workers: int = 4 + """The number of workers performing the dataloading from either disk/RAM, which + includes collating, pixel sampling, unprojecting, ray generation etc.""" + prefetch_factor: int = 10 + """The limit number of batches a worker will start loading once an iterator is created. + More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" class FullImageDatamanager(DataManager, Generic[TDataset]): """ @@ -121,6 +129,7 @@ def __init__( self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") self.train_dataset = self.create_train_dataset() self.eval_dataset = self.create_eval_dataset() + # print(type(self.train_dataset)) # prints InputDataset if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": CONSOLE.print( "Train dataset has over 500 images, overriding cache_images to cpu", @@ -138,6 +147,19 @@ def __init__( self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" + if self.config.cache_images == "disk": + self.imagebatch_stream = ImageBatchStream( + input_dataset=self.train_dataset, + datamanager_config=self.config, + + ) + self.image_dataloader = torch.utils.data.DataLoader( + self.imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + pin_memory_device=self.device, + ) super().__init__() def sample_train_cameras(self): @@ -383,7 +405,7 @@ def _undistort_image( mask = None if camera.camera_type.item() == CameraType.PERSPECTIVE.value: assert distortion_params[3] == 0, ( - "We doesn't support the 4th Brown parameter for image undistortion, " + "We don't support the 4th Brown parameter for image undistortion, " "Only k1, k2, k3, p1, p2 can be non-zero." ) distortion_params = np.array( @@ -553,6 +575,7 @@ def _undistort_image( ## Let's implement a parallelized splat dataloader! def undistort_idx(idx: int, dataset: TDataset, config: FullImageDatamanagerConfig) -> Dict[str, torch.Tensor]: + """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to pinhole""" data = dataset.get_data(idx, image_type=config.cache_images_type) camera = dataset.cameras[idx].reshape(()) assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( @@ -581,7 +604,7 @@ def undistort_idx(idx: int, dataset: TDataset, config: FullImageDatamanagerConfi import math class ImageBatchStream(torch.utils.data.IterableDataset): """ - A datamanager that outputs full images and cameras instead of raybundles. This makes the + A wrapper of InputDataset that outputs undistorted full images and cameras instead of raybundles. This makes the datamanager more lightweight since we don't have to do generate rays. Useful for full-image training e.g. rasterization pipelines """ @@ -611,18 +634,22 @@ def __iter__(self): slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load r = random.Random(self.config.train_cameras_sampling_seed) - idx = 0 + i = 0 # i refers to how many times this worker has returned an undistorted image-and-camera view while True: - if idx == per_worker: # if we've iterated through all the worker's partition of images, we need to reshuffle + if i == per_worker: # if we've iterated through all the worker's partition of images, we need to reshuffle r.shuffle(worker_indices) - idx = 0 - idx += 1 + i = 0 + idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve if self.config.cache_images != "disk": # TODO: somehow use the fact we've stored everything in self.cached_train and cached_eval - yield self.cached_train[worker_indices[idx]], self.dataset.cameras[idx].reshape(()) + yield self.cached_train[idx], self.dataset.cameras[idx].reshape(()) else: # when the images are only stored on disk, we need to undistort them and load them into memory data = undistort_idx(idx, self.dataset, self.config) - - + camera = self.dataset.cameras[idx : idx + 1].to(self.device) + if camera.metadata is None: + camera.metadata = {} + camera.metadata["cam_idx"] = idx + i += 1 + yield camera, data diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index c00b8481d6..2c7a1b1925 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -87,7 +87,6 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" - # image = np.array(pil_image, copy=False) if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index 7be9d6f6dc..a3944de0b7 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -117,3 +117,7 @@ def get_depth_image_from_path( image = image.astype(np.float64) * scale_factor image = cv2.resize(image, (width, height), interpolation=interpolation) return torch.from_numpy(image[:, :, np.newaxis]) + +def identity_collate(x): + """This function does nothing but serves to help our dataloaders have a pickleable function, as lambdas are not pickleable""" + return x \ No newline at end of file From 2fd08625c93e6d3937f6081e81a9368840c81580 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 23 Aug 2024 03:40:44 -0700 Subject: [PATCH 33/78] lots of process on a parallel FullImageDatamanger --- nerfstudio/configs/method_configs.py | 5 +- .../data/datamanagers/base_datamanager.py | 1 + .../datamanagers/full_images_datamanager.py | 119 ++++++++++++------ nerfstudio/models/splatfacto.py | 2 +- nerfstudio/pipelines/base_pipeline.py | 5 + 5 files changed, 95 insertions(+), 37 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index d1deb594d0..1c6da80e9d 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -27,7 +27,7 @@ from nerfstudio.configs.base_config import ViewerConfig from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig -from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig +from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig, ParallelFullImageDatamanager from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig @@ -37,6 +37,7 @@ from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig +from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.datasets.depth_dataset import DepthDataset from nerfstudio.data.datasets.sdf_dataset import SDFDataset from nerfstudio.data.datasets.semantic_dataset import SemanticDataset @@ -601,6 +602,8 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( + _target=ParallelFullImageDatamanager[InputDataset], + # dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", ), diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 46ca983557..d77cf211cb 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -206,6 +206,7 @@ def __init__(self): self.train_count = 0 self.eval_count = 0 if self.train_dataset and self.test_mode != "inference": + # print(self.setup_train) # prints self.setup_train() if self.eval_dataset and self.test_mode != "inference": self.setup_eval() diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index bf4c246a45..be797eb879 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -85,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): dataloader_num_workers: int = 4 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = 10 + prefetch_factor: int = 1 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" @@ -147,19 +147,6 @@ def __init__( self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" - if self.config.cache_images == "disk": - self.imagebatch_stream = ImageBatchStream( - input_dataset=self.train_dataset, - datamanager_config=self.config, - - ) - self.image_dataloader = torch.utils.data.DataLoader( - self.imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - pin_memory_device=self.device, - ) super().__init__() def sample_train_cameras(self): @@ -222,7 +209,6 @@ def _load_images( dataset = self.eval_dataset else: assert_never(split) - def undistort_idx(idx: int) -> Dict[str, torch.Tensor]: data = dataset.get_data(idx, image_type=self.config.cache_images_type) camera = dataset.cameras[idx].reshape(()) @@ -365,7 +351,6 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: # Make sure to re-populate the unseen cameras list if we have exhausted it if len(self.train_unseen_cameras) == 0: self.train_unseen_cameras = self.sample_train_cameras() - data = self.cached_train[image_idx] data["image"] = data["image"].to(self.device) @@ -574,9 +559,9 @@ def _undistort_image( ## Let's implement a parallelized splat dataloader! -def undistort_idx(idx: int, dataset: TDataset, config: FullImageDatamanagerConfig) -> Dict[str, torch.Tensor]: +def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to pinhole""" - data = dataset.get_data(idx, image_type=config.cache_images_type) + data = dataset.get_data(idx, image_type) camera = dataset.cameras[idx].reshape(()) assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' @@ -604,22 +589,23 @@ def undistort_idx(idx: int, dataset: TDataset, config: FullImageDatamanagerConfi import math class ImageBatchStream(torch.utils.data.IterableDataset): """ - A wrapper of InputDataset that outputs undistorted full images and cameras instead of raybundles. This makes the + A wrapper of InputDataset that outputs undistorted full images and cameras. This makes the datamanager more lightweight since we don't have to do generate rays. Useful for full-image training e.g. rasterization pipelines """ - config: FullImageDatamanagerConfig - dataset: TDataset # can be a train dataset or an eval dataset def __init__( self, datamanager_config: DataManagerConfig, input_dataset: TDataset, + device, ): self.config = datamanager_config - self.dataset = input_dataset - + self.input_dataset = input_dataset + self.device = device + def __iter__(self): + # print(self.input_dataset.cameras.device) prints cpu dataset_indices = list( range(len(self.input_dataset)) ) @@ -634,22 +620,85 @@ def __iter__(self): slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load r = random.Random(self.config.train_cameras_sampling_seed) - i = 0 # i refers to how many times this worker has returned an undistorted image-and-camera view + r.shuffle(worker_indices) + i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera while True: - if i == per_worker: # if we've iterated through all the worker's partition of images, we need to reshuffle + if i % per_worker == 0: # if we've iterated through all the worker's partition of images, we need to reshuffle r.shuffle(worker_indices) i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve - if self.config.cache_images != "disk": - # TODO: somehow use the fact we've stored everything in self.cached_train and cached_eval - - yield self.cached_train[idx], self.dataset.cameras[idx].reshape(()) - else: # when the images are only stored on disk, we need to undistort them and load them into memory - data = undistort_idx(idx, self.dataset, self.config) - camera = self.dataset.cameras[idx : idx + 1].to(self.device) - if camera.metadata is None: - camera.metadata = {} - camera.metadata["cam_idx"] = idx + data = undistort_idx(idx, self.input_dataset, self.config.cache_images_type) + camera = self.input_dataset.cameras[idx : idx + 1]#.to(self.device) + if camera.metadata is None: + camera.metadata = {} + camera.metadata["cam_idx"] = idx i += 1 + if torch.sum(camera.camera_to_worlds) == 0: + print(i, camera.camera_to_worlds, "YOYO INSIDE IMAGEBATCHSTREAM") yield camera, data +class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): + def __init__( + self, + config: FullImageDatamanagerConfig, + device: Union[torch.device, str] = "cpu", + test_mode: Literal["test", "val", "inference"] = "val", + world_size: int = 1, + local_rank: int = 0, + **kwargs + ): + import torch.multiprocessing as mp + mp.set_start_method("spawn") + super().__init__( + config=config, + device=device, + test_mode=test_mode, + world_size=world_size, + local_rank=local_rank, + **kwargs + ) + + def setup_train(self): + self.train_imagebatch_stream = ImageBatchStream( + input_dataset=self.train_dataset, + datamanager_config=self.config, + device=self.device, + ) + self.train_image_dataloader = torch.utils.data.DataLoader( + self.train_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + pin_memory_device=self.device, + ) + self.iter_train_image_dataloader = iter(self.train_image_dataloader) + + def setup_eval(self): + self.eval_imagebatch_stream = ImageBatchStream( + input_dataset=self.eval_dataset, + datamanager_config=self.config, + device=self.device, + ) + self.eval_image_dataloader = torch.utils.data.DataLoader( + self.eval_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + pin_memory_device=self.device, + ) + self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) # these things output tuples + + @property + def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: + return self.iter_eval_image_dataloader + + def next_train(self, step: int) -> Tuple[Cameras, Dict]: + self.train_count += 1 + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + + def next_eval(self, step: int) -> Tuple[Cameras, Dict]: + self.eval_count += 1 + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + \ No newline at end of file diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 702bdc5073..e5c30fe2f4 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -724,7 +724,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: render_mode = "RGB+ED" else: render_mode = "RGB" - + # breakpoint() if self.config.sh_degree > 0: sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) else: diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 4fd2de0df4..2a25e9352c 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -298,6 +298,11 @@ def get_train_loss_dict(self, step: int): step: current iteration step to update sampler if using DDP (distributed) """ ray_bundle, batch = self.datamanager.next_train(step) + # print(type(ray_bundle), type(batch)) + if torch.sum(ray_bundle.camera_to_worlds) == 0: + print("YOYOYO WE INSIDE THE PIPELINE", step, ray_bundle.camera_to_worlds) + # breakpoint() + ray_bundle = ray_bundle.to(self.device) # print("ray_bundle.origins.get_device()", ray_bundle.origins.get_device()) # prints 0 (it's on CUDA) # print("batch['image'].get_device()", batch["image"].get_device()) # prints -1 (it's on CPU) model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 From 846e2f3d11baea7a8f3b89dc010a616e9fae1658 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 24 Aug 2024 04:47:28 -0700 Subject: [PATCH 34/78] can train big splats with pre-assertion hack or ROI hack and 0 workers --- .../datamanagers/full_images_datamanager.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index be797eb879..23125d737e 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -82,7 +82,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): samples from the pool of all training cameras without replacement before a new round of sampling starts.""" use_image_train_dataloader: bool = cache_images == "disk" """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - dataloader_num_workers: int = 4 + dataloader_num_workers: int = 0 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" prefetch_factor: int = 1 @@ -393,6 +393,9 @@ def _undistort_image( "We don't support the 4th Brown parameter for image undistortion, " "Only k1, k2, k3, p1, p2 can be non-zero." ) + #print(distortion_params) # [ 0.05517609 -0.07427584 0. 0. -0.00026702 -0.00060216] + # we rearrange the distortion parameters because OpenCV expects the order (k1, k2, p1, p2, k3) + # see https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html distortion_params = np.array( [ distortion_params[0], @@ -411,13 +414,18 @@ def _undistort_image( K[1, 2] = K[1, 2] - 0.5 if np.any(distortion_params): newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) + breakpoint() image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore + # print("1:", image.shape) # prints (960, 540, 3) else: newK = K roi = 0, 0, image.shape[1], image.shape[0] # crop the image and update the intrinsics accordingly x, y, w, h = roi - image = image[y : y + h, x : x + w] + # print(x, y, w, h) # prints 0, 0, 539, 959 + # image = image[y : y + h, x : x + w] + # print("2:", image.shape) # prints (959, 539, 3) + if "depth_image" in data: data["depth_image"] = data["depth_image"][y : y + h, x : x + w] if "mask" in data: @@ -554,7 +562,7 @@ def _undistort_image( K = undist_K.numpy() else: raise NotImplementedError("Only perspective and fisheye cameras are supported") - + # print("final:", image.shape, camera.width, camera.height) # prints 'final: (959, 539, 3) tensor([540]) tensor([960])' return K, image, mask @@ -563,17 +571,22 @@ def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "flo """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to pinhole""" data = dataset.get_data(idx, image_type) camera = dataset.cameras[idx].reshape(()) + # dataset.cameras.width[idx] = data["image"].shape[1] + # dataset.cameras.height[idx] = data["image"].shape[0] + if idx == 48: + print("beginning", camera.width, camera.height) assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' - f'does not match the camera parameters ({camera.width.item(), camera.height.item()})' + f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' ) if camera.distortion_params is None or torch.all(camera.distortion_params == 0): return data K = camera.get_intrinsics_matrices().numpy() distortion_params = camera.distortion_params.numpy() image = data["image"].numpy() - K, image, mask = _undistort_image(camera, distortion_params, data, image, K) + # print(image.shape[1]) # outputs 539 + # print(cameras[48].reshape(()).width.item()) # outputs 540 data["image"] = torch.from_numpy(image) if mask is not None: data["mask"] = mask @@ -584,6 +597,8 @@ def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "flo dataset.cameras.cy[idx] = float(K[1, 2]) dataset.cameras.width[idx] = image.shape[1] dataset.cameras.height[idx] = image.shape[0] + if idx == 48: + print("ending", camera.width, camera.height) return data import math From 8fb0b4d5795bbb474c6654088562a645bee6f631 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 26 Aug 2024 23:39:59 -0700 Subject: [PATCH 35/78] fixed all undistortion issues with ParallelImageDatamanager --- .../datamanagers/full_images_datamanager.py | 60 +++++++++++++++---- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 23125d737e..ced5315060 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -414,7 +414,6 @@ def _undistort_image( K[1, 2] = K[1, 2] - 0.5 if np.any(distortion_params): newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) - breakpoint() image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore # print("1:", image.shape) # prints (960, 540, 3) else: @@ -423,7 +422,9 @@ def _undistort_image( # crop the image and update the intrinsics accordingly x, y, w, h = roi # print(x, y, w, h) # prints 0, 0, 539, 959 - # image = image[y : y + h, x : x + w] + image = image[y : y + h, x : x + w] + newK[0, 2] -= x + newK[1, 2] -= y # print("2:", image.shape) # prints (959, 539, 3) if "depth_image" in data: @@ -568,12 +569,13 @@ def _undistort_image( ## Let's implement a parallelized splat dataloader! def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: - """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to pinhole""" + """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to a linear camera model""" data = dataset.get_data(idx, image_type) camera = dataset.cameras[idx].reshape(()) # dataset.cameras.width[idx] = data["image"].shape[1] # dataset.cameras.height[idx] = data["image"].shape[0] if idx == 48: + # breakpoint() print("beginning", camera.width, camera.height) assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' @@ -590,17 +592,51 @@ def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "flo data["image"] = torch.from_numpy(image) if mask is not None: data["mask"] = mask - - dataset.cameras.fx[idx] = float(K[0, 0]) - dataset.cameras.fy[idx] = float(K[1, 1]) - dataset.cameras.cx[idx] = float(K[0, 2]) - dataset.cameras.cy[idx] = float(K[1, 2]) - dataset.cameras.width[idx] = image.shape[1] - dataset.cameras.height[idx] = image.shape[0] + # dataset.cameras.fx[idx] = float(K[0, 0]) + # dataset.cameras.fy[idx] = float(K[1, 1]) + # dataset.cameras.cx[idx] = float(K[0, 2]) + # dataset.cameras.cy[idx] = float(K[1, 2]) + # dataset.cameras.width[idx] = image.shape[1] + # dataset.cameras.height[idx] = image.shape[0] + # dataset.cameras.distortion_params = None if idx == 48: print("ending", camera.width, camera.height) return data +def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: + """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics + Note: this method does not modify the dataset's attributes at all. + + Returns: The undistorted data (image, depth, mask, etc.) and the new linear Camera object + """ + data = dataset.get_data(idx, image_type) + camera = dataset.cameras[idx].reshape(()) + assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( + f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' + f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' + ) + if camera.distortion_params is None or torch.all(camera.distortion_params == 0): + return data + K = camera.get_intrinsics_matrices().numpy() + distortion_params = camera.distortion_params.numpy() + image = data["image"].numpy() + K, image, mask = _undistort_image(camera, distortion_params, data, image, K) + data["image"] = torch.from_numpy(image) + if mask is not None: + data["mask"] = mask + + # create a new Camera + new_camera = Cameras( + camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), + fx=torch.Tensor([[float(K[0, 0])]]), + fy=torch.Tensor([[float(K[1, 1])]]), + cx=torch.Tensor([[float(K[0, 2])]]), + cy=torch.Tensor([[float(K[1, 2])]]), + width=torch.Tensor([[image.shape[1]]]).to(torch.int32), + height=torch.Tensor([[image.shape[0]]]).to(torch.int32), + ) + return data, new_camera + import math class ImageBatchStream(torch.utils.data.IterableDataset): """ @@ -642,8 +678,8 @@ def __iter__(self): r.shuffle(worker_indices) i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve - data = undistort_idx(idx, self.input_dataset, self.config.cache_images_type) - camera = self.input_dataset.cameras[idx : idx + 1]#.to(self.device) + data, camera = undistort_view(idx, self.input_dataset, self.config.cache_images_type) + camera2 = self.input_dataset.cameras[idx : idx + 1]#.to(self.device) if camera.metadata is None: camera.metadata = {} camera.metadata["cam_idx"] = idx From ce3f83fec7ca9337143c6c780e0bdad504cec456 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 30 Aug 2024 18:41:59 -0700 Subject: [PATCH 36/78] adding some downsampling and parallel tests with splatfacto! --- nerfstudio/configs/method_configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 1c6da80e9d..e235c46a1e 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -603,8 +603,8 @@ pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( _target=ParallelFullImageDatamanager[InputDataset], - # dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), - dataparser=NerfstudioDataParserConfig(load_3D_points=True), + dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), + # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", ), model=SplatfactoModelConfig(), From 8ab996394e30bc70d9dfbfc25606a7936740ec87 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 30 Aug 2024 19:03:22 -0700 Subject: [PATCH 37/78] deleted commented code in dataloaders.py and added bugfix to shuffling --- .../datamanagers/full_images_datamanager.py | 6 +- nerfstudio/data/utils/dataloaders.py | 153 ------------------ 2 files changed, 3 insertions(+), 156 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index ced5315060..2b88ba629b 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -82,10 +82,10 @@ class FullImageDatamanagerConfig(DataManagerConfig): samples from the pool of all training cameras without replacement before a new round of sampling starts.""" use_image_train_dataloader: bool = cache_images == "disk" """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - dataloader_num_workers: int = 0 + dataloader_num_workers: int = 8 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = 1 + prefetch_factor: int = 2 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" @@ -674,7 +674,7 @@ def __iter__(self): r.shuffle(worker_indices) i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera while True: - if i % per_worker == 0: # if we've iterated through all the worker's partition of images, we need to reshuffle + if i >= len(worker_indices): # if we've iterated through all the worker's partition of images, we need to reshuffle r.shuffle(worker_indices) i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 8c75566f48..8ef6df1353 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -175,159 +175,6 @@ def __iter__(self): pass - - -# class RayBatchStream(torch.utils.data.IterableDataset): -# def __init__( -# self, -# input_dataset: Dataset, -# num_images_to_sample_from: int = -1, -# device: Union[torch.device, str] = "cpu", -# collate_fn: Callable[[Any], Any] = nerfstudio_collate, -# exclude_batch_keys_from_device: Optional[List[str]] = None, -# num_image_load_threads : int = 2, -# cache_all_n_shard_per_worker : bool = True, -# ): -# if exclude_batch_keys_from_device is None: -# exclude_batch_keys_from_device = ["image"] -# self.input_dataset = input_dataset -# assert isinstance(self.input_dataset, Sized) - -# # super().__init__(dataset=dataset, **kwargs) # This will set self.dataset - -# # self.num_times_to_repeat_images = num_times_to_repeat_images -# # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) -# # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from -# self.num_images_to_sample_from = num_images_to_sample_from -# self.device = device -# self.collate_fn = collate_fn -# # self.num_workers = kwargs.get("num_workers", 32) # nb only 4 in defaults -# self.num_image_load_threads = num_image_load_threads #kwargs.get("num_workers", 4) # nb only 4 in defaults -# self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - -# self.pixel_sampler = None -# self.ray_generator = None -# self._cached_collated_batch = None -# self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker - -# def _get_batch_list(self, indices=None): -# """Returns a list of batches from the dataset attribute.""" - -# assert isinstance(self.input_dataset, Sized) -# if not indices: -# indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) -# # indices = range(len(self.input_dataset)) -# batch_list = [] -# results = [] - -# # num_threads = int(self.num_ds_load_threads) * 4 -# num_threads = ( -# int(self.num_image_load_threads) if not self.cache_all_n_shard_per_worker -# else 4 * int(self.num_image_load_threads)) -# num_threads = min(num_threads, multiprocessing.cpu_count() - 1) -# num_threads = max(num_threads, 1) -# # print('num_threads', num_threads) - -# # NB: this is I/O heavy, hence multi-threaded inside the worker -# from tqdm.auto import tqdm -# with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: -# for idx in indices: -# res = executor.submit(self.input_dataset.__getitem__, idx) -# results.append(res) - -# # for res in track(results, description="Loading data batch", transient=True): -# # for res in tqdm(results, desc='_get_batch_list'): -# if self.cache_all_n_shard_per_worker: -# results = tqdm(results) -# for res in results: -# batch_list.append(res.result()) -# return batch_list - -# # def _get_pixel_sampler(self, dataset: 'TDataset', num_rays_per_batch: int) -> PixelSampler: -# # """copy-pasta from VanillaDataManager.""" -# # from nerfstudio.cameras.cameras import Cameras, CameraType -# # from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig - -# # if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: -# # return PatchPixelSamplerConfig().setup( -# # patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch -# # ) -# # is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() -# # if is_equirectangular.any(): -# # CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - -# # fisheye_crop_radius = None -# # if dataset.cameras.metadata is not None: -# # fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - -# # return self.datamanager_config.pixel_sampler.setup( -# # is_equirectangular=is_equirectangular, -# # num_rays_per_batch=num_rays_per_batch, -# # fisheye_crop_radius=fisheye_crop_radius, -# # ) - -# def _get_collated_batch(self, indices=None): -# """Returns a collated batch.""" -# batch_list = self._get_batch_list(indices=indices) -# # print('running collate_fn', self.collate_fn) -# collated_batch = self.collate_fn(batch_list) -# # print('done collate_fn') -# # assert False, (self.exclude_batch_keys_from_device, collated_batch) -# collated_batch = get_dict_to_torch( -# collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device -# ) -# # print('done get_dict_to_torch') -# # print('_get_collated_batch') -# return collated_batch - -# def __iter__(self): -# # Set up stuff now that we're in the worker process -# if self.cache_all_n_shard_per_worker: -# this_indices = list(range(len(self.input_dataset))) -# worker_info = torch.utils.data.get_worker_info() -# if worker_info is None: -# print('TODO log. only single worker not sharding!') -# worker_id = -1 -# else: -# # assign this worker a deterministic uniformly sampled slice -# # of the dataset -# import math -# per_worker = int(math.ceil(len(this_indices) / float(worker_info.num_workers))) -# r = random.Random(1337) -# r.shuffle(this_indices) -# worker_id = worker_info.id -# slice_start = worker_id * per_worker -# this_indices = this_indices[slice_start:slice_start+per_worker] -# print(f'Worker ID {worker_id} working on {len(this_indices)} indices') - -# import time -# start = time.time() -# print(f"Worker ID {worker_id} caching collated batch ...") -# self._cached_collated_batch = self._get_collated_batch(indices=this_indices) -# print(f"Worker ID {worker_id} cached collated batch in {time.time()-start} sec ...") - -# if self.pixel_sampler is None: -# self.pixel_sampler = self._get_pixel_sampler( -# self.input_dataset, -# self.datamanager_config.train_num_rays_per_batch) -# if self.ray_generator is None: -# self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - -# # if self._cached_collated_batch is None: -# # self._cached_collated_batch = self._get_collated_batch() -# # print('did _cached_collated_batch') -# while True: -# if self._cached_collated_batch is None: -# collated_batch = self._get_collated_batch() -# else: -# collated_batch = self._cached_collated_batch -# # batch = self.pixel_sampler.sample(self._cached_collated_batch) -# batch = self.pixel_sampler.sample(collated_batch) -# ray_indices = batch["indices"] -# ray_bundle = self.ray_generator(ray_indices) -# yield ray_bundle, batch - - class EvalDataloader(DataLoader): """Evaluation dataloader base class From c9e16bffd920a417d26892b45e89b71db4df50e9 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 00:28:04 -0700 Subject: [PATCH 38/78] testing splatfacto-big --- nerfstudio/configs/method_configs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index e235c46a1e..53c5e86a63 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -654,7 +654,12 @@ max_num_iterations=30000, mixed_precision=False, pipeline=VanillaPipelineConfig( + # datamanager=FullImageDatamanagerConfig( + # dataparser=NerfstudioDataParserConfig(load_3D_points=True), + # cache_images_type="uint8", + # ), datamanager=FullImageDatamanagerConfig( + _target=ParallelFullImageDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", ), From ddac38d7a984fddc1f673f4449733305e5fa2683 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 00:33:52 -0700 Subject: [PATCH 39/78] cleaned up base_pipeline.py --- nerfstudio/pipelines/base_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 2a25e9352c..4618f9740c 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -299,8 +299,8 @@ def get_train_loss_dict(self, step: int): """ ray_bundle, batch = self.datamanager.next_train(step) # print(type(ray_bundle), type(batch)) - if torch.sum(ray_bundle.camera_to_worlds) == 0: - print("YOYOYO WE INSIDE THE PIPELINE", step, ray_bundle.camera_to_worlds) + # if torch.sum(ray_bundle.camera_to_worlds) == 0: # I only used this to test the splatfacto fullimage datamanager + # print("YOYOYO WE INSIDE THE PIPELINE", step, ray_bundle.camera_to_worlds) # breakpoint() ray_bundle = ray_bundle.to(self.device) # print("ray_bundle.origins.get_device()", ray_bundle.origins.get_device()) # prints 0 (it's on CUDA) From 443719a33fec799a232ba51b34ae11db83f74186 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 00:34:44 -0700 Subject: [PATCH 40/78] cleaned up base_pipeline.py ACTUALLY THIS TIME, forgot to save last time --- nerfstudio/pipelines/base_pipeline.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 4618f9740c..f29b38a4e0 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -298,13 +298,7 @@ def get_train_loss_dict(self, step: int): step: current iteration step to update sampler if using DDP (distributed) """ ray_bundle, batch = self.datamanager.next_train(step) - # print(type(ray_bundle), type(batch)) - # if torch.sum(ray_bundle.camera_to_worlds) == 0: # I only used this to test the splatfacto fullimage datamanager - # print("YOYOYO WE INSIDE THE PIPELINE", step, ray_bundle.camera_to_worlds) - # breakpoint() ray_bundle = ray_bundle.to(self.device) - # print("ray_bundle.origins.get_device()", ray_bundle.origins.get_device()) # prints 0 (it's on CUDA) - # print("batch['image'].get_device()", batch["image"].get_device()) # prints -1 (it's on CPU) model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 metrics_dict = self.model.get_metrics_dict(model_outputs, batch) loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) From d16e5192aa1210070089a5b915870c44c087c09b Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 04:11:46 -0700 Subject: [PATCH 41/78] cleaned up a lot of code --- .../data/datamanagers/base_datamanager.py | 205 +++++++----------- 1 file changed, 77 insertions(+), 128 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index d77cf211cb..b9ce1c01e9 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -62,6 +62,7 @@ RandIndicesEvalDataloader, ) from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate +from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import IterableWrapper, get_orig_class @@ -97,31 +98,6 @@ def variable_res_collate(batch: List[Dict]) -> Dict: return new_batch -def ray_collate(batch: List[RayBundle]): - # start = time.time() - ray_bundle_list, batch_list = list(zip(*batch)) - combined_metadata = {} - if "fisheye_crop_radius" in ray_bundle_list[0].metadata: - combined_metadata["fisheye_crop_radius"] = ray_bundle_list[0].metadata["fisheye_crop_radius"] - if "directions_norm" in ray_bundle_list[0].metadata: - combined_metadata["directions_norm"] = torch.cat([ray_bundle_i.metadata["directions_norm"] for ray_bundle_i in ray_bundle_list], dim=0) - - concatenated_ray_bundle = RayBundle( - origins=torch.cat([ray_bundle_i.origins for ray_bundle_i in ray_bundle_list], dim=0), - directions=torch.cat([ray_bundle_i.directions for ray_bundle_i in ray_bundle_list], dim=0), - pixel_area=torch.cat([ray_bundle_i.pixel_area for ray_bundle_i in ray_bundle_list], dim=0), - camera_indices=torch.cat([ray_bundle_i.camera_indices for ray_bundle_i in ray_bundle_list], dim=0), - metadata=combined_metadata, - ) - concatenated_batch = { - "image" : torch.cat([batch_i["image"] for batch_i in batch_list], dim=0), - "indices": torch.cat([batch_i["indices"] for batch_i in batch_list], dim=0), - } - # end = time.time() - # print((end - start) * 1000) - return [[concatenated_ray_bundle, concatenated_batch]] - - @dataclass class DataManagerConfig(InstantiateConfig): """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers; @@ -342,9 +318,9 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the dataparser used to unpack the data.""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" - train_num_images_to_sample_from: int = 100 # usually -1 + train_num_images_to_sample_from: int = -1 """Number of images to sample during training iteration.""" - train_num_times_to_repeat_images: int = 10 # usually -1 + train_num_times_to_repeat_images: int = -1 """When not training on all images, number of iterations before picking new images. If -1, never pick new images.""" eval_num_rays_per_batch: int = 1024 @@ -360,20 +336,22 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the collate function to use for the train and eval dataloaders.""" camera_res_scale_factor: float = 1.0 """The scale factor for scaling spatial data such as images, mask, semantics - along with relevant information about camera intrinsics - """ + along with relevant information about camera intrinsics""" patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - prefetch_factor: int = train_num_times_to_repeat_images # prefetch_factor of 16 does well, but any that is equal train_num_times_to_repeat_images is good + use_parallel_dataloader: bool = True + """Allows parallelization of the dataloading process with multiple workers prefetching batches.""" + load_from_disk: bool = False + """If True, conserves RAM memory by loading images from disk. + If False, caches all the images as tensors to RAM.""" + prefetch_factor: int = 4 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" dataloader_num_workers: int = 4 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - use_ray_train_dataloader: bool = True - """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - cache_binaries: bool = True - """When enabled, cache raw image files to RAM""" + cache_image_bytes: bool = True + """If True, cache raw image files as byte strings to RAM.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -390,7 +368,15 @@ def __post_init__(self): "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow" ) warnings.warn("above message coming from", FutureWarning, stacklevel=3) - + + """ + These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted + Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck. + """ + if self.load_from_disk: + self.train_num_images_to_sample_from = 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from + self.train_num_times_to_repeat_images = 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images + self.prefetch_factor = self.train_num_times_to_repeat_images TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) @@ -408,46 +394,51 @@ def __post_init__(self): from torch.profiler import profile, record_function, ProfilerActivity class RayBatchStream(torch.utils.data.IterableDataset): + """Wrapper around Pytorch's IterableDataset to gerenate the next batch of rays and corresponding labels + with multiple parallel workers. + + Each worker samples a small batch of images, pixel samples those images, and generates rays for one training step. + The same batch of images can be pixel sampled multiple times hasten ray generation, as retrieving images is process + bottlenecked by disk read speed. To avoid Out-Of-Memory (OOM) errors, this batch of images is small and regenerated + by resampling the worker's partition of images to gurantee sampling diversity. + """ def __init__( self, input_dataset: Dataset, datamanager_config: DataManagerConfig, - num_images_to_sample_from: int = -1, # passed in from VanillaDataManager - num_times_to_repeat_images: int = -1, # passed in from VanillaDataManager device: Union[torch.device, str] = "cpu", - collate_fn: Callable[[Any], Any] = nerfstudio_collate, exclude_batch_keys_from_device: Optional[List[str]] = None, num_image_load_threads: int = 4, - cache_all_n_shard_per_worker: bool = True, ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] self.input_dataset = input_dataset assert isinstance(self.input_dataset, Sized) - - # self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) - """If True, cache all images to RAM as a collated""" - # self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from - self.num_images_to_sample_from = num_images_to_sample_from - self.num_times_to_repeat_images = num_times_to_repeat_images + + self.datamanager_config = datamanager_config + self.num_images_to_sample_from = self.datamanager_config.train_num_images_to_sample_from + self.num_times_to_repeat_images = self.datamanager_config.train_num_times_to_repeat_images self.device = device - # self.collate_fn = variable_res_collate # variable_res_collate avoids collating images, which is much faster than `nerfstudio_collate` - self.collate_fn = collate_fn - print("collate_fn", self.collate_fn) - print("self.device", self.device) - self.num_image_load_threads = num_image_load_threads # kwargs.get("num_workers", 4) # nb only 4 in defaults + self.collate_fn = variable_res_collate # variable_res_collate avoids np.stack'ing images, which allows it to be much faster than `nerfstudio_collate` + self.num_image_load_threads = num_image_load_threads + """Number of threads created to read images from disk and form collated batches.""" self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + """Which key of the batch (such as 'image', 'mask','depth') to prevent from moving to the device. + For instance, if you would like to conserve GPU memory, don't move the image tensors to the GPU, + which comes at a cost of total training time. The default value is ['image'] + """ # print("self.exclude_batch_keys_from_device", self.exclude_batch_keys_from_device) # usually prints ['image'] - self.datamanager_config = datamanager_config + self.pixel_sampler: PixelSampler = None self.ray_generator: RayGenerator = None + + self.enable_per_worker_image_caching = self.datamanager_config.load_from_disk == False + """If True, each worker's will cache its entire partition of the image dataset as image tensors in RAM.""" self._cached_collated_batch = None """self._cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" - self.cache_all_n_shard_per_worker = cache_all_n_shard_per_worker - """If True, self._cached_collated_batch is populated with a subset of the dataset assigned to each worker during the iteration process.""" - def _get_pixel_sampler(self, dataset: "TDataset", num_rays_per_batch: int) -> PixelSampler: - """copy-pasta from VanillaDataManager.""" + def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: + """copy-pasted from VanillaDataManager.""" from nerfstudio.cameras.cameras import CameraType from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSamplerConfig @@ -486,7 +477,7 @@ def _get_batch_list(self, indices=None): # num_threads = int(self.num_ds_load_threads) * 4 num_threads = ( int(self.num_image_load_threads) - if not self.cache_all_n_shard_per_worker + if not self.enable_per_worker_image_caching else 4 * int(self.num_image_load_threads) ) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) @@ -494,69 +485,57 @@ def _get_batch_list(self, indices=None): # NB: this is I/O heavy because we are going to disk and reading an image filename # hence multi-threaded inside the worker - # with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - # for idx in indices: - # res = executor.submit(self.input_dataset.__getitem__, idx) - # results.append(res) - - # # for res in tqdm(results, desc='_get_batch_list'): - # results = tqdm(results) # does not effect times, tested many times - # for res in results: - # batch_list.append(res.result()) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for idx in indices: + res = executor.submit(self.input_dataset.__getitem__, idx) + results.append(res) + results = tqdm(results) # this is temporary + for res in results: + batch_list.append(res.result()) - for idx in tqdm(indices): # this is slower compared to using threads, but using this allows us to profile __getitem__ - batch_list.append(self.input_dataset.__getitem__(idx)) return batch_list def _get_collated_batch(self, indices=None): - """Takes the output of _get_batch_list and collates them with nerfstudio_collate() + """Takes the output of _get_batch_list and collates them with nerfstudio_collate() or variable_res_collate() Note: dict is an instance of collections.abc.Mapping The resulting output is collated_batch: a dictionary with dict_keys(['image_idx', 'image']) collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) """ - with record_function("_get_batch_list"): - batch_list = self._get_batch_list(indices=indices) - # print(self.collate_fn) # prints nerfstudio_collate on mainRGB, but prints variable_res_collate if all3cameras - with record_function("collate_function"): - collated_batch = self.collate_fn(batch_list) - with record_function("sending to GPU"): - collated_batch = get_dict_to_torch( - collated_batch, device=self.device, #exclude=self.exclude_batch_keys_from_device - ) - # batch_list = self._get_batch_list(indices=indices) - # collated_batch = self.collate_fn(batch_list) - # collated_batch = get_dict_to_torch( - # collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device - # ) + batch_list = self._get_batch_list(indices=indices) + collated_batch = self.collate_fn(batch_list) + collated_batch = get_dict_to_torch( + collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + ) return collated_batch def __iter__(self): - """This implementation has every worker cache the indices of the images they will use to generate rays.""" - dataset_indices = list( - range(len(self.input_dataset)) - ) # this_indices has length = numTrainingImages, at first it is the whole training dataset, but it gets partitioned into equal chunks + """This implementation allows every worker only cache the indices of the images they will use to generate rays to conserve RAM memory.""" worker_info = torch.utils.data.get_worker_info() if worker_info is not None: # if we have multiple processes - per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers))) slice_start = worker_info.id * per_worker else: # we only have a single process per_worker = len(self.input_dataset) slice_start = 0 + dataset_indices = list( + range(len(self.input_dataset)) + ) worker_indices = dataset_indices[ slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load - if self.cache_all_n_shard_per_worker: + if self.enable_per_worker_image_caching: self._cached_collated_batch = self._get_collated_batch(worker_indices) r = random.Random(3301) num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch # default train_num_rays_per_batch is 4096 + # each worker has its own pixel sampler worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) - if self.ray_generator is None: - self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) + i = 0 while True: - if self.cache_all_n_shard_per_worker: + if self.enable_per_worker_image_caching: collated_batch = self._cached_collated_batch elif i % self.num_times_to_repeat_images == 0: r.shuffle(worker_indices) @@ -565,14 +544,8 @@ def __iter__(self): image_indices = worker_indices else: # get a total of 'num_images_to_sample_from' image indices image_indices = worker_indices[:self.num_images_to_sample_from] - # self._get_collated_batch is slow because it is going to disk to retreive an image many times to create a batch of images. collated_batch = self._get_collated_batch(image_indices) - # with profile(activities=[ProfilerActivity.CPU], record_shapes=True, with_stack=True,) as prof: - # with record_function("process_images"): - # collated_batch = self._get_collated_batch(image_indices) - # with open('_get_batch_list_profile.txt', 'w') as f: - # f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) i += 1 """ Here, the variable 'batch' refers to the output of our pixel sampler. @@ -583,16 +556,13 @@ def __iter__(self): and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) """ batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. - # the returned batch also somehow moves the images from the CPU to the GPU - # collated_batch["image"].get_device() will return + # the returned batch of pixels also somehow moves the images from the CPU to the GPU + # collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device == True ray_indices = batch["indices"] ray_bundle = self.ray_generator(ray_indices) # the ray_bundle is on the GPU, but batch["image"] is on the CPU yield ray_bundle, batch -def identity(x): - return x - class VanillaDataManager(DataManager, Generic[TDataset]): """Basic stored data manager implementation. @@ -722,20 +692,14 @@ def setup_train(self): assert self.train_dataset is not None CONSOLE.print("Setting up training dataset...") - if self.config.use_ray_train_dataloader: + if self.config.use_parallel_dataloader: import torch.multiprocessing as mp mp.set_start_method("spawn") - self.raybatch_stream = RayBatchStream( input_dataset=self.train_dataset, datamanager_config=self.config, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - collate_fn=self.config.collate_fn, - cache_all_n_shard_per_worker=False, ) - # This one uses identity collate self.ray_dataloader = torch.utils.data.DataLoader( self.raybatch_stream, batch_size=1, @@ -743,24 +707,10 @@ def setup_train(self): prefetch_factor=self.config.prefetch_factor, shuffle=False, pin_memory=True, - # Our dataset does batching / collation - collate_fn=identity, - pin_memory_device=self.device, # did not actually speed up my implementation + # Our dataset handles batching / collation of rays + collate_fn=identity_collate, + pin_memory_device=self.device, ) - - # # this one uses ray_collate - # self.ray_dataloader = torch.utils.data.DataLoader( - # self.raybatch_stream, - # batch_size=4, - # num_workers=self.config.dataloader_num_workers, - # prefetch_factor=self.config.prefetch_factor, - # shuffle=False, - # pin_memory=True, - # # Our dataset does batching / collation - # collate_fn=ray_collate, - # pin_memory_device=self.device, # did not actually speed up my implementation - # ) - self.iter_train_raybundles = iter(self.ray_dataloader) else: self.iter_train_raybundles = None @@ -815,11 +765,10 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" self.train_count += 1 - if self.config.use_ray_train_dataloader: + if self.config.use_parallel_dataloader: ret = next(self.iter_train_raybundles) assert len(ret) == 1, f"batch size should be one {len(ret)}" ray_bundle, batch = ret[0] - # ray_bundle = RayBundle.from_dict(ray_bundle_dict) ray_bundle = ray_bundle.to(self.device) else: image_batch = next(self.iter_train_image_dataloader) From 367d512615838ce00b2ce2caf0dfd4f57846ecc9 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 04:19:58 -0700 Subject: [PATCH 42/78] process_project_aria back to main branch and some cleanup in full_image_datamanager --- .../datamanagers/full_images_datamanager.py | 6 +- .../scripts/datasets/process_project_aria.py | 355 ++++-------------- 2 files changed, 65 insertions(+), 296 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 2b88ba629b..c20661144e 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -574,9 +574,7 @@ def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "flo camera = dataset.cameras[idx].reshape(()) # dataset.cameras.width[idx] = data["image"].shape[1] # dataset.cameras.height[idx] = data["image"].shape[0] - if idx == 48: - # breakpoint() - print("beginning", camera.width, camera.height) + assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' @@ -599,8 +597,6 @@ def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "flo # dataset.cameras.width[idx] = image.shape[1] # dataset.cameras.height[idx] = image.shape[0] # dataset.cameras.distortion_params = None - if idx == 48: - print("ending", camera.width, camera.height) return data def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: diff --git a/nerfstudio/scripts/datasets/process_project_aria.py b/nerfstudio/scripts/datasets/process_project_aria.py index 7947d3c20c..fe48748325 100644 --- a/nerfstudio/scripts/datasets/process_project_aria.py +++ b/nerfstudio/scripts/datasets/process_project_aria.py @@ -13,13 +13,11 @@ # limitations under the License. import json -import random import sys import threading from dataclasses import dataclass -from itertools import zip_longest from pathlib import Path -from typing import Any, Dict, List, Literal, Tuple, cast +from typing import Any, Dict, List, cast import numpy as np import open3d as o3d @@ -27,9 +25,8 @@ from PIL import Image try: - from projectaria_tools.core import calibration, mps + from projectaria_tools.core import mps from projectaria_tools.core.data_provider import VrsDataProvider, create_vrs_data_provider - from projectaria_tools.core.image import InterpolationMethod from projectaria_tools.core.mps.utils import filter_points_from_confidence from projectaria_tools.core.sophus import SE3 except ImportError: @@ -71,7 +68,6 @@ class AriaImageFrame: file_path: str t_world_camera: SE3 timestamp_ns: float - pinhole_intrinsic: Tuple[float, float, float, float] @dataclass @@ -80,12 +76,11 @@ class TimedPoses: t_world_devices: List[SE3] -def get_camera_calibs( - provider: VrsDataProvider, name: Literal["camera-rgb", "camera-slam-left", "camera-slam-right"] = "camera-rgb" -) -> AriaCameraCalibration: +def get_camera_calibs(provider: VrsDataProvider) -> Dict[str, AriaCameraCalibration]: """Retrieve the per-camera factory calibration from within the VRS.""" - assert name in ["camera-rgb", "camera-slam-left", "camera-slam-right"], f"{name} is not a valid camera sensor" + factory_calib = {} + name = "camera-rgb" device_calib = provider.get_device_calibration() assert device_calib is not None, "Could not find device calibration" sensor_calib = device_calib.get_camera_calib(name) @@ -106,7 +101,7 @@ def get_camera_calibs( t_device_camera=sensor_calib.get_transform_device_camera(), ) - return factory_calib[name] + return factory_calib def read_trajectory_csv_to_dict(file_iterable_csv: str) -> TimedPoses: @@ -123,101 +118,25 @@ def read_trajectory_csv_to_dict(file_iterable_csv: str) -> TimedPoses: ) -def undistort_image_and_calibration( - input_image: np.ndarray, - input_calib: calibration.CameraCalibration, - output_focal_length: int, -) -> Tuple[np.ndarray, calibration.CameraCalibration]: - """ - Return the undistorted image and the updated camera calibration. - """ - input_calib_width = input_calib.get_image_size()[0] - input_calib_height = input_calib.get_image_size()[1] - if input_image.shape[1] != input_calib_width or input_image.shape[0] != input_calib_height: - raise ValueError( - f"Input image shape {input_image.shape} does not match calibration {input_calib.get_image_size()}" - ) - - # Undistort the image - pinhole_calib = calibration.get_linear_camera_calibration( - int(input_calib_width), - int(input_calib_height), - output_focal_length, - "pinhole", - input_calib.get_transform_device_camera(), - ) - output_image = calibration.distort_by_calibration( - input_image, pinhole_calib, input_calib, InterpolationMethod.BILINEAR - ) - - return output_image, pinhole_calib - - -def rotate_upright_image_and_calibration( - input_image: np.ndarray, - input_calib: calibration.CameraCalibration, -) -> Tuple[np.ndarray, calibration.CameraCalibration]: - """ - Return the rotated upright image and update both intrinsics and extrinsics of the camera calibration - NOTE: This function only supports pinhole and fisheye624 camera model. - """ - output_image = np.rot90(input_image, k=3) - updated_calib = calibration.rotate_camera_calib_cw90deg(input_calib) - - return output_image, updated_calib - - -def generate_circular_mask(numRows: int, numCols: int, radius: float) -> np.ndarray: - """ - Generates a mask where a circle in the center of the image with input radius is white (sampled from). - Everything outside the circle is black (masked out) - """ - # Calculate the center coordinates - rows, cols = np.ogrid[:numRows, :numCols] - center_row, center_col = numRows // 2, numCols // 2 - - # Calculate the distance of each pixel from the center - distance_from_center = np.sqrt((rows - center_row) ** 2 + (cols - center_col) ** 2) - mask = np.zeros((numRows, numCols), dtype=np.uint8) - mask[distance_from_center <= radius] = 1 - return mask - - def to_aria_image_frame( provider: VrsDataProvider, index: int, name_to_camera: Dict[str, AriaCameraCalibration], t_world_devices: TimedPoses, output_dir: Path, - camera_name: str = "camera-rgb", - pinhole: bool = False, ) -> AriaImageFrame: - aria_cam_calib = name_to_camera[camera_name] - stream_id = provider.get_stream_id_from_label(camera_name) - assert stream_id is not None, f"Could not find stream {camera_name}" + name = "camera-rgb" - # Retrieve the current camera calibration - device_calib = provider.get_device_calibration() - assert device_calib is not None, "Could not find device calibration" - src_calib = device_calib.get_camera_calib(camera_name) - assert isinstance(src_calib, calibration.CameraCalibration), "src_calib is not of type CameraCalibration" + camera_calibration = name_to_camera[name] + stream_id = provider.get_stream_id_from_label(name) + assert stream_id is not None, f"Could not find stream {name}" - # Get the image corresponding to this index and undistort it + # Get the image corresponding to this index image_data = provider.get_image_data_by_index(stream_id, index) - image_array, intrinsic = image_data[0].to_numpy_array().astype(np.uint8), (0, 0, 0, 0) - if pinhole: - f_length = 500 if camera_name == "camera-rgb" else 170 - image_array, src_calib = undistort_image_and_calibration(image_array, src_calib, f_length) - intrinsic = (f_length, f_length, image_array.shape[1] // 2, image_array.shape[0] // 2) - - # Rotate the image right side up - image_array, src_calib = rotate_upright_image_and_calibration(image_array, src_calib) - img = Image.fromarray(image_array) + img = Image.fromarray(image_data[0].to_numpy_array()) capture_time_ns = image_data[1].capture_timestamp_ns - intrinsic = (intrinsic[0], intrinsic[1], intrinsic[3], intrinsic[2]) - # Save the image - file_path = f"{output_dir}/{camera_name}_{capture_time_ns}.jpg" + file_path = f"{output_dir}/{name}_{capture_time_ns}.jpg" threading.Thread(target=lambda: img.save(file_path)).start() # Find the nearest neighbor pose with the closest timestamp to the capture time. @@ -227,46 +146,17 @@ def to_aria_image_frame( t_world_device = t_world_devices.t_world_devices[nearest_pose_idx] # Compute the world to camera transform. - t_world_camera = t_world_device @ src_calib.get_transform_device_camera() @ T_ARIA_NERFSTUDIO - - # Define new AriaCameraCalibration since we rotated the image to be upright - width = src_calib.get_image_size()[0].item() - height = src_calib.get_image_size()[1].item() - intrinsics = src_calib.projection_params() - aria_cam_calib = AriaCameraCalibration( - fx=intrinsics[0], - fy=intrinsics[0], - cx=intrinsics[1], - cy=intrinsics[2], - distortion_params=intrinsics[3:15], - width=width, - height=height, - t_device_camera=src_calib.get_transform_device_camera(), - ) + t_world_camera = t_world_device @ camera_calibration.t_device_camera @ T_ARIA_NERFSTUDIO return AriaImageFrame( - camera=aria_cam_calib, + camera=camera_calibration, file_path=file_path, t_world_camera=t_world_camera, timestamp_ns=capture_time_ns, - pinhole_intrinsic=intrinsic, ) -def to_nerfstudio_frame(frame: AriaImageFrame, pinhole: bool = False, mask_path: str = "") -> Dict: - if pinhole: - return { - "fl_x": frame.pinhole_intrinsic[0], - "fl_y": frame.pinhole_intrinsic[1], - "cx": frame.pinhole_intrinsic[2], - "cy": frame.pinhole_intrinsic[3], - "w": frame.pinhole_intrinsic[2] * 2, - "h": frame.pinhole_intrinsic[3] * 2, - "file_path": frame.file_path, - "transform_matrix": frame.t_world_camera.to_matrix().tolist(), - "timestamp": frame.timestamp_ns, - "mask_path": mask_path, - } +def to_nerfstudio_frame(frame: AriaImageFrame) -> Dict: return { "fl_x": frame.camera.fx, "fl_y": frame.camera.fy, @@ -288,187 +178,70 @@ class ProcessProjectAria: https://facebookresearch.github.io/projectaria_tools/docs/ARK/mps. """ - vrs_file: Tuple[Path, ...] - """Path to the VRS file(s).""" - mps_data_dir: Tuple[Path, ...] + vrs_file: Path + """Path to the VRS file.""" + mps_data_dir: Path """Path to Project Aria Machine Perception Services (MPS) attachments.""" output_dir: Path """Path to the output directory.""" - points_file: Tuple[Path, ...] = () - """Path to the point cloud file (usually called semidense_points.csv.gz) if not in the mps_data_dir""" - include_side_cameras: bool = False - """If True, include and process the images captured by the grayscale side cameras. - If False, only uses the main RGB camera's data.""" - max_dataset_size: int = -1 - """Max number of images to train on. If the provided vrs_file has more images than max_dataset_size, - images will be sampled approximately evenly. If max_dataset_size=-1, use all images available.""" def main(self) -> None: """Generate a nerfstudio dataset from ProjectAria data (VRS) and MPS attachments.""" - # Create output directory if it doesn't exist + # Create output directory if it doesn't exist. self.output_dir = self.output_dir.absolute() self.output_dir.mkdir(parents=True, exist_ok=True) - # Create list of tuples containing files from each wearer and output variables - assert len(self.vrs_file) == len( - self.mps_data_dir - ), "Please provide an Aria MPS attachment for each corresponding VRS file." - vrs_mps_points_triplets = list(zip_longest(self.vrs_file, self.mps_data_dir, self.points_file)) # type: ignore - num_recordings = len(vrs_mps_points_triplets) + provider = create_vrs_data_provider(str(self.vrs_file.absolute())) + assert provider is not None, "Cannot open file" + + name_to_camera = get_camera_calibs(provider) + + print("Getting poses from closed loop trajectory CSV...") + trajectory_csv = self.mps_data_dir / "closed_loop_trajectory.csv" + t_world_devices = read_trajectory_csv_to_dict(str(trajectory_csv.absolute())) + + name = "camera-rgb" + stream_id = provider.get_stream_id_from_label(name) + + # create an AriaImageFrame for each image in the VRS. + print("Creating Aria frames...") + aria_frames = [ + to_aria_image_frame(provider, index, name_to_camera, t_world_devices, self.output_dir) + for index in range(0, provider.get_num_data(stream_id)) + ] + + # create the NerfStudio frames from the AriaImageFrames. + print("Creating NerfStudio frames...") + CANONICAL_RGB_VALID_RADIUS = 707.5 + CANONICAL_RGB_WIDTH = 1408 + rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * (aria_frames[0].camera.width / CANONICAL_RGB_WIDTH) nerfstudio_frames = { - "camera_model": "OPENCV" if self.include_side_cameras else ARIA_CAMERA_MODEL, - "frames": [], + "camera_model": ARIA_CAMERA_MODEL, + "frames": [to_nerfstudio_frame(frame) for frame in aria_frames], + "fisheye_crop_radius": rgb_valid_radius, } - points = [] - names = ["camera-rgb", "camera-slam-left", "camera-slam-right"] - total_num_rgb_images_per_recording_list = [] - total_num_images_per_recording_list = [] - - # Count the total number of images per dataset - for rec_i, (vrs_file, mps_data_dir, points_file) in enumerate(vrs_mps_points_triplets): - provider = create_vrs_data_provider(str(vrs_file.absolute())) - assert provider is not None, "Cannot open file" - stream_ids = [provider.get_stream_id_from_label(name) for name in names] - total_num_rgb_images_per_recording_list.append(provider.get_num_data(stream_ids[0])) - total_num_images_per_recording_list.append( - sum([provider.get_num_data(stream_id) for stream_id in stream_ids]) - ) - if not self.include_side_cameras: - assert self.max_dataset_size <= sum( - total_num_rgb_images_per_recording_list - ), "Specify a dataset size at most the number of RGB images provided" - else: - assert self.max_dataset_size <= sum( - total_num_images_per_recording_list - ), "Specify a dataset size at most the number of images provided" - - # Process the aria data of each user one by one - for rec_i, (vrs_file, mps_data_dir, points_file) in enumerate(vrs_mps_points_triplets): - provider = create_vrs_data_provider(str(vrs_file.absolute())) - assert provider is not None, "Cannot open file" - - name_to_camera = { - name: get_camera_calibs(provider, name) # type: ignore - for name in names - } # name_to_camera is of type Dict[str, AriaCameraCalibration] - - print(f"Getting poses from recording {rec_i + 1}'s closed loop trajectory CSV...") - trajectory_csv = mps_data_dir / "closed_loop_trajectory.csv" - t_world_devices = read_trajectory_csv_to_dict(str(trajectory_csv.absolute())) - - stream_ids = [provider.get_stream_id_from_label(name) for name in names] - - # Create an AriaImageFrame for each image in the VRS - print(f"Creating Aria frames for recording {rec_i + 1}...") - CANONICAL_RGB_VALID_RADIUS = 707.5 # radius of a circular mask that represents the valid area on the camera's sensor plane. Pixels out of this circular region are considered invalid - CANONICAL_RGB_WIDTH = 1408 - - if not self.include_side_cameras: # RGB images only - if self.max_dataset_size == -1: - sampling_indices = range(provider.get_num_data(stream_ids[0])) - else: - num_images_to_sample = ( - self.max_dataset_size * total_num_rgb_images_per_recording_list[rec_i] - ) // sum(total_num_rgb_images_per_recording_list) - sampling_indices = random.sample(range(provider.get_num_data(stream_ids[0])), num_images_to_sample) - aria_rgb_frames = [ - to_aria_image_frame( - provider, index, name_to_camera, t_world_devices, self.output_dir, camera_name=names[0] - ) - for index in sampling_indices - ] - print(f"Creating NerfStudio frames for recording {rec_i + 1}...") - nerfstudio_frames["frames"] += [to_nerfstudio_frame(frame) for frame in aria_rgb_frames] - rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * ( - aria_rgb_frames[0].camera.width / CANONICAL_RGB_WIDTH - ) # to handle both high-res 2880 x 2880 aria captures - nerfstudio_frames["fisheye_crop_radius"] = rgb_valid_radius - else: # include the side grayscale cameras - total_num_images_per_camera_list = [provider.get_num_data(stream_id) for stream_id in stream_ids] - if self.max_dataset_size == -1: - sampling_indices_list = [range(num_images) for num_images in total_num_images_per_camera_list] - else: - total_num_images = sum( - total_num_images_per_camera_list - ) # total number of images for this recording - num_images_to_sample = ( - self.max_dataset_size // num_recordings - ) # total number of images to sample for this recording - num_images_to_sample_per_camera_list = [ - num_images_to_sample * num // total_num_images for num in total_num_images_per_camera_list - ] - sampling_indices_list = [ - random.sample( - range(total_num_images_per_camera_list[i]), num_images_to_sample_per_camera_list[i] - ) - for i in range(3) - ] - aria_all3cameras_pinhole_frames = [ - [ - to_aria_image_frame( - provider, - index, - name_to_camera, - t_world_devices, - self.output_dir, - camera_name=names[i], - pinhole=True, - ) - for index in sampling_indices_list[i] - ] - for i, stream_id in enumerate(stream_ids) - ] - # Generate masks for undistorted images - rgb_width = aria_all3cameras_pinhole_frames[0][0].camera.width - rgb_valid_radius = CANONICAL_RGB_VALID_RADIUS * (rgb_width / CANONICAL_RGB_WIDTH) - slam_valid_radius = 330.0 # found here: https://github.com/facebookresearch/projectaria_tools/blob/4aee633cb667ab927825dc10477cad0df8393a34/core/calibration/loader/SensorCalibrationJson.cpp#L102C5-L104C18 - rgb_mask_nparray, slam_mask_nparray = ( - generate_circular_mask(rgb_width, rgb_width, rgb_valid_radius), - generate_circular_mask(640, 480, slam_valid_radius), - ) - rgb_mask_filepath, slam_mask_filepath = ( - f"{self.output_dir}/rgb_mask.jpg", - f"{self.output_dir}/slam_mask.jpg", - ) - Image.fromarray(rgb_mask_nparray).save(rgb_mask_filepath) - Image.fromarray(slam_mask_nparray).save(slam_mask_filepath) - - print(f"Creating NerfStudio frames for recording {rec_i + 1}...") - mask_filepaths = [rgb_mask_filepath, slam_mask_filepath, slam_mask_filepath] - pinhole_frames = [ - to_nerfstudio_frame(frame, pinhole=True, mask_path=mask_filepath) - for i, mask_filepath in enumerate(mask_filepaths) - for frame in aria_all3cameras_pinhole_frames[i] - ] - nerfstudio_frames["frames"] += pinhole_frames - - if points_file is not None: - points_path = points_file - else: - points_path = mps_data_dir / "global_points.csv.gz" - if not points_path.exists(): - # MPS point cloud output was renamed in Aria's December 4th, 2023 update. - # https://facebookresearch.github.io/projectaria_tools/docs/ARK/sw_release_notes#project-aria-updates-aria-mobile-app-v140-and-changes-to-mps - points_path = mps_data_dir / "semidense_points.csv.gz" - - if points_path.exists(): - print(f"Found global points for recording {rec_i+1}") - points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore - points_data = filter_points_from_confidence(points_data) - points += [cast(Any, it).position_world for it in points_data] - - if len(points) > 0: - print("Saving found points to PLY...") - print(f"Total number of points found: {len(points)} in {num_recordings} recording(s) provided") + + # save global point cloud, which is useful for Gaussian Splatting. + points_path = self.mps_data_dir / "global_points.csv.gz" + if not points_path.exists(): + # MPS point cloud output was renamed in Aria's December 4th, 2023 update. + # https://facebookresearch.github.io/projectaria_tools/docs/ARK/sw_release_notes#project-aria-updates-aria-mobile-app-v140-and-changes-to-mps + points_path = self.mps_data_dir / "semidense_points.csv.gz" + + if points_path.exists(): + print("Found global points, saving to PLY...") + points_data = mps.read_global_point_cloud(str(points_path)) # type: ignore + points_data = filter_points_from_confidence(points_data) pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(np.array(points)) + pcd.points = o3d.utility.Vector3dVector(np.array([cast(Any, it).position_world for it in points_data])) ply_file_path = self.output_dir / "global_points.ply" o3d.io.write_point_cloud(str(ply_file_path), pcd) + nerfstudio_frames["ply_file_path"] = "global_points.ply" else: print("No global points found!") - print(len(nerfstudio_frames["frames"])) - # Write the json out to disk as transforms.json + + # write the json out to disk as transforms.json print("Writing transforms.json") transform_file = self.output_dir / "transforms.json" with open(transform_file, "w", encoding="UTF-8"): @@ -477,4 +250,4 @@ def main(self) -> None: if __name__ == "__main__": tyro.extras.set_accent_color("bright_yellow") - tyro.cli(ProcessProjectAria).main() \ No newline at end of file + tyro.cli(ProcessProjectAria).main() From d3d99b442e046dacb14559151d5e9f4a4b3ba801 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 1 Sep 2024 04:43:19 -0700 Subject: [PATCH 43/78] clarifying docstrings --- nerfstudio/data/utils/dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 8ef6df1353..30a725aa51 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -44,7 +44,7 @@ class CacheDataloader(DataLoader): Args: dataset: Dataset to sample from. num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. - num_times_to_repeat_images: How many ray bundles to . -1 to never pick new images. + num_times_to_repeat_images: How often to yield an image batch before resampling. -1 to never pick new images. device: Device to perform computation. collate_fn: The function we will use to collate our training data """ From 6f763dc7c255497f5ab940f61dbb0d6ca7fcdcac Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 3 Sep 2024 02:22:00 -0700 Subject: [PATCH 44/78] further PR cleanup --- .../datamanagers/full_images_datamanager.py | 43 ++----------------- nerfstudio/data/utils/dataloaders.py | 29 ------------- nerfstudio/data/utils/nerfstudio_collate.py | 16 +++---- nerfstudio/models/splatfacto.py | 2 +- 4 files changed, 11 insertions(+), 79 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index c20661144e..bcaf108aeb 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -393,7 +393,6 @@ def _undistort_image( "We don't support the 4th Brown parameter for image undistortion, " "Only k1, k2, k3, p1, p2 can be non-zero." ) - #print(distortion_params) # [ 0.05517609 -0.07427584 0. 0. -0.00026702 -0.00060216] # we rearrange the distortion parameters because OpenCV expects the order (k1, k2, p1, p2, k3) # see https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html distortion_params = np.array( @@ -421,11 +420,9 @@ def _undistort_image( roi = 0, 0, image.shape[1], image.shape[0] # crop the image and update the intrinsics accordingly x, y, w, h = roi - # print(x, y, w, h) # prints 0, 0, 539, 959 image = image[y : y + h, x : x + w] newK[0, 2] -= x newK[1, 2] -= y - # print("2:", image.shape) # prints (959, 539, 3) if "depth_image" in data: data["depth_image"] = data["depth_image"][y : y + h, x : x + w] @@ -568,37 +565,6 @@ def _undistort_image( ## Let's implement a parallelized splat dataloader! -def undistort_idx(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: - """Undistorts an image to one taken by a linear (pinhole) camera model and updates the dataset's camera intrinsics to a linear camera model""" - data = dataset.get_data(idx, image_type) - camera = dataset.cameras[idx].reshape(()) - # dataset.cameras.width[idx] = data["image"].shape[1] - # dataset.cameras.height[idx] = data["image"].shape[0] - - assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( - f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' - f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' - ) - if camera.distortion_params is None or torch.all(camera.distortion_params == 0): - return data - K = camera.get_intrinsics_matrices().numpy() - distortion_params = camera.distortion_params.numpy() - image = data["image"].numpy() - K, image, mask = _undistort_image(camera, distortion_params, data, image, K) - # print(image.shape[1]) # outputs 539 - # print(cameras[48].reshape(()).width.item()) # outputs 540 - data["image"] = torch.from_numpy(image) - if mask is not None: - data["mask"] = mask - # dataset.cameras.fx[idx] = float(K[0, 0]) - # dataset.cameras.fy[idx] = float(K[1, 1]) - # dataset.cameras.cx[idx] = float(K[0, 2]) - # dataset.cameras.cy[idx] = float(K[1, 2]) - # dataset.cameras.width[idx] = image.shape[1] - # dataset.cameras.height[idx] = image.shape[0] - # dataset.cameras.distortion_params = None - return data - def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics Note: this method does not modify the dataset's attributes at all. @@ -621,7 +587,7 @@ def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "fl if mask is not None: data["mask"] = mask - # create a new Camera + # create a new Camera with the rectified / undistorted intrinsics new_camera = Cameras( camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), fx=torch.Tensor([[float(K[0, 0])]]), @@ -675,13 +641,10 @@ def __iter__(self): i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve data, camera = undistort_view(idx, self.input_dataset, self.config.cache_images_type) - camera2 = self.input_dataset.cameras[idx : idx + 1]#.to(self.device) if camera.metadata is None: camera.metadata = {} camera.metadata["cam_idx"] = idx i += 1 - if torch.sum(camera.camera_to_worlds) == 0: - print(i, camera.camera_to_worlds, "YOYO INSIDE IMAGEBATCHSTREAM") yield camera, data class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): @@ -716,7 +679,7 @@ def setup_train(self): batch_size=1, num_workers=self.config.dataloader_num_workers, collate_fn=identity_collate, - pin_memory_device=self.device, + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? ) self.iter_train_image_dataloader = iter(self.train_image_dataloader) @@ -731,7 +694,7 @@ def setup_eval(self): batch_size=1, num_workers=self.config.dataloader_num_workers, collate_fn=identity_collate, - pin_memory_device=self.device, + # pin_memory_device=self.device, ) self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) # these things output tuples diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 30a725aa51..427792ca86 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -32,7 +32,6 @@ from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate -from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_dict_to_torch from nerfstudio.utils.rich_utils import CONSOLE @@ -134,7 +133,6 @@ def __iter__(self): collated_batch = self.cached_collated_batch elif self.first_time or ( self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images - # if it's the first time, we need to ): # trigger a reset self.num_repeated = 0 @@ -148,33 +146,6 @@ def __iter__(self): yield collated_batch -import torch -class ParallelCacheDataloader(torch.utils.data.IterableDataset): - """Creates batches of the InputDataset return type with multiple workers, can be toggled to return image batches or RayBundles - When return image batches - """ - def __init__( - self, - input_dataset: Dataset, - num_images_to_sample_from: int = -1, - device: Union[torch.device, str] = "cpu", - collate_fn: Callable[[Any], Any] = nerfstudio_collate, - exclude_batch_keys_from_device: Optional[List[str]] = None, - num_image_load_threads : int = 2, - cache_all_n_shard_per_worker : bool = True, - ): - if exclude_batch_keys_from_device is None: - exclude_batch_keys_from_device = ["image"] - self.input_dataset = input_dataset - assert isinstance(self.input_dataset, Sized) - - self.num_images_to_sample_from = num_images_to_sample_from - """The size of a collated_batch of images""" - - def __iter__(self): - pass - - class EvalDataloader(DataLoader): """Evaluation dataloader base class diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index f05ea80882..b5b391f543 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -25,7 +25,6 @@ import torch.utils.data from nerfstudio.cameras.cameras import Cameras -from torch.profiler import profile, record_function NERFSTUDIO_COLLATE_ERR_MSG_FORMAT = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts, lists or anything in {}; found {}" @@ -95,15 +94,14 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None - with record_function("creating shared memory"): - if torch.utils.data.get_worker_info() is not None: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum(x.numel() for x in batch) - storage = elem.storage()._new_shared(numel, device=elem.device) - out = elem.new(storage).resize_(len(batch), *list(elem.size())) + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": + elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"): if elem_type.__name__ in ("ndarray", "memmap"): # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index e5c30fe2f4..702bdc5073 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -724,7 +724,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: render_mode = "RGB+ED" else: render_mode = "RGB" - # breakpoint() + if self.config.sh_degree > 0: sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) else: From a5191bd3f4625a937122de56b4db7c491b0dc280 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 9 Sep 2024 02:53:26 -0700 Subject: [PATCH 45/78] updating models --- nerfstudio/configs/method_configs.py | 106 +++++++++++++++++++++++++-- nerfstudio/models/nerfacto.py | 19 +---- 2 files changed, 100 insertions(+), 25 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 53c5e86a63..93896dc5fb 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -27,7 +27,7 @@ from nerfstudio.configs.base_config import ViewerConfig from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig -from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig, ParallelFullImageDatamanager +from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig @@ -212,6 +212,48 @@ vis="viewer", ) +method_configs["nerfacto-oom"] = TrainerConfig( + method_name="nerfacto-oom", + steps_per_eval_batch=500, + steps_per_save=2000, + max_num_iterations=30000, + mixed_precision=True, + pipeline=VanillaPipelineConfig( + datamanager=VanillaDataManagerConfig( + dataparser=NerfstudioDataParserConfig(), + train_num_rays_per_batch=4096, + eval_num_rays_per_batch=4096, + use_parallel_dataloader=True, + load_from_disk=True, + dataloader_num_workers=4, + prefetch_factor=10, + train_num_images_to_sample_from=50, + train_num_times_to_repeat_images=10, + ), + model=NerfactoModelConfig( + eval_num_rays_per_chunk=1 << 15, + average_init_density=0.01, + camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"), + ), + ), + optimizers={ + "proposal_networks": { + "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), + }, + "fields": { + "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), + }, + "camera_opt": { + "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=5000), + }, + }, + viewer=ViewerConfig(num_rays_per_chunk=1 << 15), + vis="viewer", +) + method_configs["depth-nerfacto"] = TrainerConfig( method_name="depth-nerfacto", steps_per_eval_batch=500, @@ -602,10 +644,11 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( - _target=ParallelFullImageDatamanager[InputDataset], + # _target=ParallelFullImageDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", + use_parallel_dataloader=True, ), model=SplatfactoModelConfig(), ), @@ -654,12 +697,7 @@ max_num_iterations=30000, mixed_precision=False, pipeline=VanillaPipelineConfig( - # datamanager=FullImageDatamanagerConfig( - # dataparser=NerfstudioDataParserConfig(load_3D_points=True), - # cache_images_type="uint8", - # ), datamanager=FullImageDatamanagerConfig( - _target=ParallelFullImageDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", ), @@ -705,6 +743,60 @@ vis="viewer", ) +method_configs["splatfacto-oom"] = TrainerConfig( + method_name="splatfacto-oom", + steps_per_eval_image=100, + steps_per_eval_batch=0, + steps_per_save=2000, + steps_per_eval_all_images=1000, + max_num_iterations=30000, + mixed_precision=False, + pipeline=VanillaPipelineConfig( + datamanager=FullImageDatamanagerConfig( + # _target=ParallelFullImageDatamanager[InputDataset], + dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), + # dataparser=NerfstudioDataParserConfig(load_3D_points=True), + cache_images_type="uint8", + use_parallel_dataloader=True, + ), + model=SplatfactoModelConfig(), + ), + optimizers={ + "means": { + "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig( + lr_final=1.6e-6, + max_steps=30000, + ), + }, + "features_dc": { + "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), + "scheduler": None, + }, + "features_rest": { + "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), + "scheduler": None, + }, + "opacities": { + "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), + "scheduler": None, + }, + "scales": { + "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), + "scheduler": None, + }, + "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, + "camera_opt": { + "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig( + lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 + ), + }, + }, + viewer=ViewerConfig(num_rays_per_chunk=1 << 15), + vis="viewer", +) + def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True): """Merge new methods and descriptions into existing methods and descriptions. diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py index 71e60fbb2c..bfccfd8797 100644 --- a/nerfstudio/models/nerfacto.py +++ b/nerfstudio/models/nerfacto.py @@ -46,7 +46,7 @@ from nerfstudio.model_components.shaders import NormalsShader from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils import colormaps -from torch.profiler import profile, record_function, ProfilerActivity + @dataclass class NerfactoModelConfig(ModelConfig): @@ -363,23 +363,6 @@ def get_metrics_dict(self, outputs, batch): def get_loss_dict(self, outputs, batch, metrics_dict=None): loss_dict = {} image = batch["image"].to(self.device) - # Start profiling - if image.dtype == torch.uint8: - image = image / torch.tensor(255, dtype=torch.float32, device=self.device) - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - # record_shapes=True, - # profile_memory=True, - # with_stack=True) as prof: - # with record_function("image_normalization"): - # image = image / torch.tensor(255, dtype=torch.float32, device=self.device) - # # image = image.float() / 255.0 - - # # Write profiler results to a file - # profile_path = "profiler_results.txt" - # with open(profile_path, "w") as f: - # f.write(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - image = image.to(self.device) pred_rgb, gt_rgb = self.renderer_rgb.blend_background_for_loss_computation( pred_image=outputs["rgb"], pred_accumulation=outputs["accumulation"], From 7db70dca0f088177c4282624d5e4794f29675e3d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 9 Sep 2024 02:59:56 -0700 Subject: [PATCH 46/78] further cleanup --- .../data/datamanagers/base_datamanager.py | 20 +- .../datamanagers/full_images_datamanager.py | 186 +++++++++++------- 2 files changed, 121 insertions(+), 85 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index b9ce1c01e9..8c5d828f39 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -316,6 +316,8 @@ class VanillaDataManagerConfig(DataManagerConfig): """Target class to instantiate.""" dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig) """Specifies the dataparser used to unpack the data.""" + cache_images_type: Literal["uint8", "float32"] = "float32" + """The image type returned from manager, caching images in uint8 saves memory""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" train_num_images_to_sample_from: int = -1 @@ -339,18 +341,18 @@ class VanillaDataManagerConfig(DataManagerConfig): along with relevant information about camera intrinsics""" patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - use_parallel_dataloader: bool = True - """Allows parallelization of the dataloading process with multiple workers prefetching batches.""" + use_parallel_dataloader: bool = False + """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles.""" load_from_disk: bool = False """If True, conserves RAM memory by loading images from disk. - If False, caches all the images as tensors to RAM.""" - prefetch_factor: int = 4 - """The limit number of batches a worker will start loading once an iterator is created. - More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - dataloader_num_workers: int = 4 + If False, caches all the images as tensors to RAM and loads from RAM.""" + dataloader_num_workers: int = 0 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - cache_image_bytes: bool = True + prefetch_factor: int = None + """The limit number of batches a worker will start loading once an iterator is created. + More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" + cache_image_bytes: bool = False """If True, cache raw image files as byte strings to RAM.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. @@ -706,7 +708,7 @@ def setup_train(self): num_workers=self.config.dataloader_num_workers, prefetch_factor=self.config.prefetch_factor, shuffle=False, - pin_memory=True, + # pin_memory=True, # Our dataset handles batching / collation of rays collate_fn=identity_collate, pin_memory_device=self.device, diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index bcaf108aeb..751acfadbf 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -57,13 +57,6 @@ class FullImageDatamanagerConfig(DataManagerConfig): """The scale factor for scaling spatial data such as images, mask, semantics along with relevant information about camera intrinsics """ - eval_num_images_to_sample_from: int = -1 - """Number of images to sample during eval iteration.""" - eval_num_times_to_repeat_images: int = -1 - """When not evaluating on all images, number of iterations before picking - new images. If -1, never pick new images.""" - eval_image_indices: Optional[Tuple[int, ...]] = (0,) - """Specifies the image indices to use during eval; if None, uses all.""" cache_images: Literal["cpu", "gpu", "disk"] = "gpu" """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device. If "disk", keeps images on disk. """ cache_images_type: Literal["uint8", "float32"] = "float32" @@ -80,9 +73,9 @@ class FullImageDatamanagerConfig(DataManagerConfig): fps_reset_every: int = 100 """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every samples from the pool of all training cameras without replacement before a new round of sampling starts.""" - use_image_train_dataloader: bool = cache_images == "disk" + use_parallel_dataloader: bool = cache_images == "disk" """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - dataloader_num_workers: int = 8 + dataloader_num_workers: int = 4 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" prefetch_factor: int = 2 @@ -129,7 +122,7 @@ def __init__( self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") self.train_dataset = self.create_train_dataset() self.eval_dataset = self.create_eval_dataset() - # print(type(self.train_dataset)) # prints InputDataset + if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": CONSOLE.print( "Train dataset has over 500 images, overriding cache_images to cpu", @@ -147,6 +140,10 @@ def __init__( self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" + if self.config.use_parallel_dataloader: + import torch.multiprocessing as mp + mp.set_start_method("spawn") + super().__init__() def sample_train_cameras(self): @@ -313,15 +310,45 @@ def get_datapath(self) -> Path: def setup_train(self): """Sets up the data loaders for training""" - + if self.config.use_parallel_dataloader: + self.train_imagebatch_stream = ImageBatchStream( + input_dataset=self.train_dataset, + datamanager_config=self.config, + device=self.device, + ) + self.train_image_dataloader = torch.utils.data.DataLoader( + self.train_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? + ) + self.iter_train_image_dataloader = iter(self.train_image_dataloader) + def setup_eval(self): """Sets up the data loader for evaluation""" + if self.config.use_parallel_dataloader: + self.eval_imagebatch_stream = ImageBatchStream( + input_dataset=self.eval_dataset, + datamanager_config=self.config, + device=self.device, + ) + self.eval_image_dataloader = torch.utils.data.DataLoader( + self.eval_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + # pin_memory_device=self.device, + ) + self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) @property def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: """ Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples """ + if self.config.use_parallel_dataloader: + return self.iter_eval_image_dataloader image_indices = [i for i in range(len(self.eval_dataset))] data = deepcopy(self.cached_eval) _cameras = deepcopy(self.eval_dataset.cameras).to(self.device) @@ -347,6 +374,11 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next training batch Returns a Camera instead of raybundle""" + self.train_count += 1 + if self.config.use_parallel_dataloader: + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + image_idx = self.train_unseen_cameras.pop(0) # Make sure to re-populate the unseen cameras list if we have exhausted it if len(self.train_unseen_cameras) == 0: @@ -365,6 +397,11 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next evaluation batch Returns a Camera instead of raybundle""" + self.eval_count += 1 + if self.config.use_parallel_dataloader: + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + return self.next_eval_image(step=step) def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: @@ -414,7 +451,6 @@ def _undistort_image( if np.any(distortion_params): newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore - # print("1:", image.shape) # prints (960, 540, 3) else: newK = K roi = 0, 0, image.shape[1], image.shape[0] @@ -560,11 +596,9 @@ def _undistort_image( K = undist_K.numpy() else: raise NotImplementedError("Only perspective and fisheye cameras are supported") - # print("final:", image.shape, camera.width, camera.height) # prints 'final: (959, 539, 3) tensor([540]) tensor([960])' return K, image, mask -## Let's implement a parallelized splat dataloader! def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics Note: this method does not modify the dataset's attributes at all. @@ -647,68 +681,68 @@ def __iter__(self): i += 1 yield camera, data -class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): - def __init__( - self, - config: FullImageDatamanagerConfig, - device: Union[torch.device, str] = "cpu", - test_mode: Literal["test", "val", "inference"] = "val", - world_size: int = 1, - local_rank: int = 0, - **kwargs - ): - import torch.multiprocessing as mp - mp.set_start_method("spawn") - super().__init__( - config=config, - device=device, - test_mode=test_mode, - world_size=world_size, - local_rank=local_rank, - **kwargs - ) +# class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): +# def __init__( +# self, +# config: FullImageDatamanagerConfig, +# device: Union[torch.device, str] = "cpu", +# test_mode: Literal["test", "val", "inference"] = "val", +# world_size: int = 1, +# local_rank: int = 0, +# **kwargs +# ): +# import torch.multiprocessing as mp +# mp.set_start_method("spawn") +# super().__init__( +# config=config, +# device=device, +# test_mode=test_mode, +# world_size=world_size, +# local_rank=local_rank, +# **kwargs +# ) - def setup_train(self): - self.train_imagebatch_stream = ImageBatchStream( - input_dataset=self.train_dataset, - datamanager_config=self.config, - device=self.device, - ) - self.train_image_dataloader = torch.utils.data.DataLoader( - self.train_imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? - ) - self.iter_train_image_dataloader = iter(self.train_image_dataloader) - - def setup_eval(self): - self.eval_imagebatch_stream = ImageBatchStream( - input_dataset=self.eval_dataset, - datamanager_config=self.config, - device=self.device, - ) - self.eval_image_dataloader = torch.utils.data.DataLoader( - self.eval_imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - # pin_memory_device=self.device, - ) - self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) # these things output tuples - - @property - def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: - return self.iter_eval_image_dataloader - - def next_train(self, step: int) -> Tuple[Cameras, Dict]: - self.train_count += 1 - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data +# def setup_train(self): +# self.train_imagebatch_stream = ImageBatchStream( +# input_dataset=self.train_dataset, +# datamanager_config=self.config, +# device=self.device, +# ) +# self.train_image_dataloader = torch.utils.data.DataLoader( +# self.train_imagebatch_stream, +# batch_size=1, +# num_workers=self.config.dataloader_num_workers, +# collate_fn=identity_collate, +# # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? +# ) +# self.iter_train_image_dataloader = iter(self.train_image_dataloader) + +# def setup_eval(self): +# self.eval_imagebatch_stream = ImageBatchStream( +# input_dataset=self.eval_dataset, +# datamanager_config=self.config, +# device=self.device, +# ) +# self.eval_image_dataloader = torch.utils.data.DataLoader( +# self.eval_imagebatch_stream, +# batch_size=1, +# num_workers=self.config.dataloader_num_workers, +# collate_fn=identity_collate, +# # pin_memory_device=self.device, +# ) +# self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) + +# @property +# def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: +# return self.iter_eval_image_dataloader + +# def next_train(self, step: int) -> Tuple[Cameras, Dict]: +# self.train_count += 1 +# camera, data = next(self.iter_train_image_dataloader)[0] +# return camera, data - def next_eval(self, step: int) -> Tuple[Cameras, Dict]: - self.eval_count += 1 - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data +# def next_eval(self, step: int) -> Tuple[Cameras, Dict]: +# self.eval_count += 1 +# camera, data = next(self.iter_train_image_dataloader)[0] +# return camera, data \ No newline at end of file From 5c3262b03358ecea4a8a0f33c79f05d6936ff159 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 9 Sep 2024 03:03:02 -0700 Subject: [PATCH 47/78] removed caching of images into bytestrings --- nerfstudio/data/datamanagers/base_datamanager.py | 3 +-- nerfstudio/data/datasets/base_dataset.py | 14 +------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 8c5d828f39..a0a99e7ce6 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -352,8 +352,7 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - cache_image_bytes: bool = False - """If True, cache raw image files as byte strings to RAM.""" + # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 2c7a1b1925..d0c5c4857b 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -46,7 +46,7 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = False): + def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -55,18 +55,6 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = self.cameras = deepcopy(dataparser_outputs.cameras) self.cameras.rescale_output_resolution(scaling_factor=scale_factor) self.mask_color = dataparser_outputs.metadata.get("mask_color", None) - self.cache_images = cache_images - """If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)""" - if cache_images: - self.binary_images = [] - self.binary_masks = [] - for image_filename in self._dataparser_outputs.image_filenames: - with open(image_filename, 'rb') as f: - self.binary_images.append(io.BytesIO(f.read())) - if self._dataparser_outputs.mask_filenames is not None: - for mask_filename in self._dataparser_outputs.mask_filenames: - with open(mask_filename, 'rb') as f: - self.binary_masks.append(io.BytesIO(f.read())) def __len__(self): return len(self._dataparser_outputs.image_filenames) From ff2bda1789a8f06623a29aa717d6c1a62cba29b5 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 9 Sep 2024 15:52:50 -0700 Subject: [PATCH 48/78] adding caching of compressed images to RAM, forgot that hardware matters --- nerfstudio/data/datamanagers/base_datamanager.py | 3 ++- nerfstudio/data/datasets/base_dataset.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index a0a99e7ce6..8c5d828f39 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -352,7 +352,8 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - + cache_image_bytes: bool = False + """If True, cache raw image files as byte strings to RAM.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index d0c5c4857b..2c7a1b1925 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -46,7 +46,7 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0): + def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = False): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -55,6 +55,18 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = self.cameras = deepcopy(dataparser_outputs.cameras) self.cameras.rescale_output_resolution(scaling_factor=scale_factor) self.mask_color = dataparser_outputs.metadata.get("mask_color", None) + self.cache_images = cache_images + """If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)""" + if cache_images: + self.binary_images = [] + self.binary_masks = [] + for image_filename in self._dataparser_outputs.image_filenames: + with open(image_filename, 'rb') as f: + self.binary_images.append(io.BytesIO(f.read())) + if self._dataparser_outputs.mask_filenames is not None: + for mask_filename in self._dataparser_outputs.mask_filenames: + with open(mask_filename, 'rb') as f: + self.binary_masks.append(io.BytesIO(f.read())) def __len__(self): return len(self._dataparser_outputs.image_filenames) From f6dd7dd89c0d0b7fe30d75a1b95e44290167c428 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 14 Sep 2024 18:01:16 -0700 Subject: [PATCH 49/78] removing oom methods, adding the ability to add a flag to dataloading --- nerfstudio/configs/method_configs.py | 97 +--------------------------- 1 file changed, 1 insertion(+), 96 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 93896dc5fb..e9f98b274a 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -95,6 +95,7 @@ dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=4096, + load_from_disk=True, ), model=NerfactoModelConfig( eval_num_rays_per_chunk=1 << 15, @@ -212,48 +213,6 @@ vis="viewer", ) -method_configs["nerfacto-oom"] = TrainerConfig( - method_name="nerfacto-oom", - steps_per_eval_batch=500, - steps_per_save=2000, - max_num_iterations=30000, - mixed_precision=True, - pipeline=VanillaPipelineConfig( - datamanager=VanillaDataManagerConfig( - dataparser=NerfstudioDataParserConfig(), - train_num_rays_per_batch=4096, - eval_num_rays_per_batch=4096, - use_parallel_dataloader=True, - load_from_disk=True, - dataloader_num_workers=4, - prefetch_factor=10, - train_num_images_to_sample_from=50, - train_num_times_to_repeat_images=10, - ), - model=NerfactoModelConfig( - eval_num_rays_per_chunk=1 << 15, - average_init_density=0.01, - camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"), - ), - ), - optimizers={ - "proposal_networks": { - "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), - }, - "fields": { - "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), - }, - "camera_opt": { - "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=5000), - }, - }, - viewer=ViewerConfig(num_rays_per_chunk=1 << 15), - vis="viewer", -) - method_configs["depth-nerfacto"] = TrainerConfig( method_name="depth-nerfacto", steps_per_eval_batch=500, @@ -743,60 +702,6 @@ vis="viewer", ) -method_configs["splatfacto-oom"] = TrainerConfig( - method_name="splatfacto-oom", - steps_per_eval_image=100, - steps_per_eval_batch=0, - steps_per_save=2000, - steps_per_eval_all_images=1000, - max_num_iterations=30000, - mixed_precision=False, - pipeline=VanillaPipelineConfig( - datamanager=FullImageDatamanagerConfig( - # _target=ParallelFullImageDatamanager[InputDataset], - dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), - # dataparser=NerfstudioDataParserConfig(load_3D_points=True), - cache_images_type="uint8", - use_parallel_dataloader=True, - ), - model=SplatfactoModelConfig(), - ), - optimizers={ - "means": { - "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig( - lr_final=1.6e-6, - max_steps=30000, - ), - }, - "features_dc": { - "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), - "scheduler": None, - }, - "features_rest": { - "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), - "scheduler": None, - }, - "opacities": { - "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), - "scheduler": None, - }, - "scales": { - "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), - "scheduler": None, - }, - "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, - "camera_opt": { - "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig( - lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 - ), - }, - }, - viewer=ViewerConfig(num_rays_per_chunk=1 << 15), - vis="viewer", -) - def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True): """Merge new methods and descriptions into existing methods and descriptions. From a6602c71f223801ba994e50921bdf0cfb5542bb6 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 14 Sep 2024 18:02:45 -0700 Subject: [PATCH 50/78] removed CacheDataloader, moved RayBatchStream to dataloaders.py, new vanilla_datamanager rewritten --- .../data/datamanagers/base_datamanager.py | 349 +++-------------- .../data/datamanagers/datamanager_configs.py | 34 ++ .../data/datamanagers/parallel_datamanager.py | 2 +- nerfstudio/data/utils/dataloaders.py | 362 ++++++++++++++---- 4 files changed, 384 insertions(+), 363 deletions(-) create mode 100644 nerfstudio/data/datamanagers/datamanager_configs.py diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 8c5d828f39..3a91cd3749 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -57,7 +57,8 @@ from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig from nerfstudio.data.utils.dataloaders import ( - CacheDataloader, + # CacheDataloader, + RayBatchStream, FixedIndicesEvalDataloader, RandIndicesEvalDataloader, ) @@ -341,7 +342,7 @@ class VanillaDataManagerConfig(DataManagerConfig): along with relevant information about camera intrinsics""" patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - use_parallel_dataloader: bool = False + use_parallel_dataloader: bool = True """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles.""" load_from_disk: bool = False """If True, conserves RAM memory by loading images from disk. @@ -352,8 +353,6 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" - cache_image_bytes: bool = False - """If True, cache raw image files as byte strings to RAM.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -378,191 +377,17 @@ def __post_init__(self): if self.load_from_disk: self.train_num_images_to_sample_from = 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from self.train_num_times_to_repeat_images = 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images - self.prefetch_factor = self.train_num_times_to_repeat_images - -TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) - -import concurrent.futures -import math -import multiprocessing -import random -from typing import Sized - -from torch.utils.data import Dataset - -from nerfstudio.utils.misc import get_dict_to_torch -from tqdm.auto import tqdm - -from torch.profiler import profile, record_function, ProfilerActivity - -class RayBatchStream(torch.utils.data.IterableDataset): - """Wrapper around Pytorch's IterableDataset to gerenate the next batch of rays and corresponding labels - with multiple parallel workers. - - Each worker samples a small batch of images, pixel samples those images, and generates rays for one training step. - The same batch of images can be pixel sampled multiple times hasten ray generation, as retrieving images is process - bottlenecked by disk read speed. To avoid Out-Of-Memory (OOM) errors, this batch of images is small and regenerated - by resampling the worker's partition of images to gurantee sampling diversity. - """ - def __init__( - self, - input_dataset: Dataset, - datamanager_config: DataManagerConfig, - device: Union[torch.device, str] = "cpu", - exclude_batch_keys_from_device: Optional[List[str]] = None, - num_image_load_threads: int = 4, - ): - if exclude_batch_keys_from_device is None: - exclude_batch_keys_from_device = ["image"] - self.input_dataset = input_dataset - assert isinstance(self.input_dataset, Sized) - - self.datamanager_config = datamanager_config - self.num_images_to_sample_from = self.datamanager_config.train_num_images_to_sample_from - self.num_times_to_repeat_images = self.datamanager_config.train_num_times_to_repeat_images - self.device = device - self.collate_fn = variable_res_collate # variable_res_collate avoids np.stack'ing images, which allows it to be much faster than `nerfstudio_collate` - self.num_image_load_threads = num_image_load_threads - """Number of threads created to read images from disk and form collated batches.""" - self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - """Which key of the batch (such as 'image', 'mask','depth') to prevent from moving to the device. - For instance, if you would like to conserve GPU memory, don't move the image tensors to the GPU, - which comes at a cost of total training time. The default value is ['image'] - """ - # print("self.exclude_batch_keys_from_device", self.exclude_batch_keys_from_device) # usually prints ['image'] - - self.pixel_sampler: PixelSampler = None - self.ray_generator: RayGenerator = None - - self.enable_per_worker_image_caching = self.datamanager_config.load_from_disk == False - """If True, each worker's will cache its entire partition of the image dataset as image tensors in RAM.""" - self._cached_collated_batch = None - """self._cached_collated_batch contains a collated batch of images for a specific worker that's ready for pixel sampling.""" - - def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: - """copy-pasted from VanillaDataManager.""" - from nerfstudio.cameras.cameras import CameraType - from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSamplerConfig - - if self.datamanager_config.patch_size > 1 and type(self.datamanager_config.pixel_sampler) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.datamanager_config.patch_size, num_rays_per_batch=num_rays_per_batch - ) - is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - if is_equirectangular.any(): - CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - fisheye_crop_radius = None - if dataset.cameras.metadata is not None: - fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - return self.datamanager_config.pixel_sampler.setup( - is_equirectangular=is_equirectangular, - num_rays_per_batch=num_rays_per_batch, - fisheye_crop_radius=fisheye_crop_radius, - ) - - def _get_batch_list(self, indices=None): - """Returns a list representing a single batch from the dataset attribute. - Each item of the list is a dictionary with dict_keys(['image_idx', 'image']) representing 1 image. - This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch - The length of the list is equal to the (# of training images) / (num_workers)""" - - assert isinstance(self.input_dataset, Sized) - if indices is None: - # Note: self.num_images_to_sample_from is usually -1, but _get_batch_list is usually called with indices != None. - # _get_batch_list is used by _get_collated_batch, whose indices = some partition of the dataset - indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) - batch_list = [] - results = [] - - # num_threads = int(self.num_ds_load_threads) * 4 - num_threads = ( - int(self.num_image_load_threads) - if not self.enable_per_worker_image_caching - else 4 * int(self.num_image_load_threads) - ) - num_threads = min(num_threads, multiprocessing.cpu_count() - 1) - num_threads = max(num_threads, 1) # print('num_threads', num_threads) # prints 16 - - # NB: this is I/O heavy because we are going to disk and reading an image filename - # hence multi-threaded inside the worker - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for idx in indices: - res = executor.submit(self.input_dataset.__getitem__, idx) - results.append(res) - results = tqdm(results) # this is temporary - for res in results: - batch_list.append(res.result()) + self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None + + if self.use_parallel_dataloader: + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass + self.dataloader_num_workers = 4 - return batch_list - def _get_collated_batch(self, indices=None): - """Takes the output of _get_batch_list and collates them with nerfstudio_collate() or variable_res_collate() - Note: dict is an instance of collections.abc.Mapping - - The resulting output is collated_batch: a dictionary with dict_keys(['image_idx', 'image']) - collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) - collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) - """ - batch_list = self._get_batch_list(indices=indices) - collated_batch = self.collate_fn(batch_list) - collated_batch = get_dict_to_torch( - collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device - ) - return collated_batch - - def __iter__(self): - """This implementation allows every worker only cache the indices of the images they will use to generate rays to conserve RAM memory.""" - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: # if we have multiple processes - per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers))) - slice_start = worker_info.id * per_worker - else: # we only have a single process - per_worker = len(self.input_dataset) - slice_start = 0 - dataset_indices = list( - range(len(self.input_dataset)) - ) - worker_indices = dataset_indices[ - slice_start : slice_start + per_worker - ] # the indices of the datapoints in the dataset this worker will load - if self.enable_per_worker_image_caching: - self._cached_collated_batch = self._get_collated_batch(worker_indices) - r = random.Random(3301) - num_rays_per_loop = self.datamanager_config.train_num_rays_per_batch # default train_num_rays_per_batch is 4096 - # each worker has its own pixel sampler - worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) - self.ray_generator = RayGenerator(self.input_dataset.cameras)#.to(self.device)) - - i = 0 - while True: - if self.enable_per_worker_image_caching: - collated_batch = self._cached_collated_batch - elif i % self.num_times_to_repeat_images == 0: - r.shuffle(worker_indices) - - if self.num_images_to_sample_from == -1: # if -1, the worker gets all available indices in its partition - image_indices = worker_indices - else: # get a total of 'num_images_to_sample_from' image indices - image_indices = worker_indices[:self.num_images_to_sample_from] - - collated_batch = self._get_collated_batch(image_indices) - i += 1 - """ - Here, the variable 'batch' refers to the output of our pixel sampler. - - batch is a dict_keys(['image', 'indices']) - - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’] - - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol) - What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, - and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) - """ - batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. - # the returned batch of pixels also somehow moves the images from the CPU to the GPU - # collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device == True - ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices) # the ray_bundle is on the GPU, but batch["image"] is on the CPU - yield ray_bundle, batch +TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) class VanillaDataManager(DataManager, Generic[TDataset]): @@ -669,96 +494,50 @@ def create_eval_dataset(self) -> TDataset: scale_factor=self.config.camera_res_scale_factor, ) - def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: - """Infer pixel sampler to use.""" - if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch - ) - is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - if is_equirectangular.any(): - CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - fisheye_crop_radius = None - if dataset.cameras.metadata is not None: - fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - return self.config.pixel_sampler.setup( - is_equirectangular=is_equirectangular, - num_rays_per_batch=num_rays_per_batch, - fisheye_crop_radius=fisheye_crop_radius, - ) def setup_train(self): - """Sets up the data loaders for training""" - assert self.train_dataset is not None - CONSOLE.print("Setting up training dataset...") - - if self.config.use_parallel_dataloader: - import torch.multiprocessing as mp - mp.set_start_method("spawn") - self.raybatch_stream = RayBatchStream( - input_dataset=self.train_dataset, - datamanager_config=self.config, - device=self.device, - ) - self.ray_dataloader = torch.utils.data.DataLoader( - self.raybatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - prefetch_factor=self.config.prefetch_factor, - shuffle=False, - # pin_memory=True, - # Our dataset handles batching / collation of rays - collate_fn=identity_collate, - pin_memory_device=self.device, - ) - self.iter_train_raybundles = iter(self.ray_dataloader) - else: - self.iter_train_raybundles = None - self.train_image_dataloader = CacheDataloader( - self.train_dataset, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, # batch_size - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, # - device=self.device, - num_workers=self.world_size * 4 - if self.config.dataloader_num_workers == -1 - else self.config.dataloader_num_workers, - prefetch_factor=2 - if self.config.prefetch_factor == -1 - else self.config.prefetch_factor, - pin_memory=True, - collate_fn=self.config.collate_fn, - exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, - ) - self.iter_train_image_dataloader = iter(self.train_image_dataloader) - self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) - self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) - - def setup_eval(self): - """Sets up the data loader for evaluation""" - assert self.eval_dataset is not None - CONSOLE.print("Setting up evaluation dataset...") - self.eval_image_dataloader = CacheDataloader( - self.eval_dataset, - num_images_to_sample_from=self.config.eval_num_images_to_sample_from, - num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images, + self.train_raybatchstream = RayBatchStream( + input_dataset=self.train_dataset, + num_rays_per_batch=self.config.train_num_rays_per_batch, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - num_workers=self.world_size * 4, - pin_memory=True, - collate_fn=self.config.collate_fn, - exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, + collate_fn = variable_res_collate, + load_from_disk = True, + ) + self.train_ray_dataloader = torch.utils.data.DataLoader( + self.train_raybatchstream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + prefetch_factor=self.config.prefetch_factor, + shuffle=False, + # pin_memory=True, + collate_fn=identity_collate, # Our dataset handles batching / collation of rays + pin_memory_device=self.device, ) - self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) - self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch) - self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device)) - # for loading full images - self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader( + self.iter_train_raybundles = iter(self.train_ray_dataloader) + + def setup_eval(self): + self.eval_raybatchstream = RayBatchStream( input_dataset=self.eval_dataset, + num_rays_per_batch=self.config.train_num_rays_per_batch, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - num_workers=self.world_size * 4, + collate_fn = variable_res_collate, + load_from_disk = True, + ) + self.eval_ray_dataloader = torch.utils.data.DataLoader( + self.eval_raybatchstream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + prefetch_factor=self.config.prefetch_factor, + shuffle=False, + collate_fn=identity_collate, # Our dataset handles batching / collation of rays + pin_memory_device=self.device, ) - self.eval_dataloader = RandIndicesEvalDataloader( + self.iter_eval_raybundles = iter(self.eval_ray_dataloader) + self.image_eval_dataloader = RandIndicesEvalDataloader( input_dataset=self.eval_dataset, device=self.device, num_workers=self.world_size * 4, @@ -767,33 +546,23 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" self.train_count += 1 - if self.config.use_parallel_dataloader: - ret = next(self.iter_train_raybundles) - assert len(ret) == 1, f"batch size should be one {len(ret)}" - ray_bundle, batch = ret[0] - ray_bundle = ray_bundle.to(self.device) - else: - image_batch = next(self.iter_train_image_dataloader) - assert self.train_pixel_sampler is not None - assert isinstance(image_batch, dict) - batch = self.train_pixel_sampler.sample(image_batch) - ray_indices = batch["indices"] - ray_bundle = self.train_ray_generator(ray_indices) + ret = next(self.iter_train_raybundles) + assert len(ret) == 1, f"batch size should be one" + ray_bundle, batch = ret[0] + ray_bundle = ray_bundle.to(self.device) return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: - """Returns the next batch of data from the eval dataloader.""" + """Returns the next batch of data from the train dataloader.""" self.eval_count += 1 - image_batch = next(self.iter_eval_image_dataloader) - assert self.eval_pixel_sampler is not None - assert isinstance(image_batch, dict) - batch = self.eval_pixel_sampler.sample(image_batch) - ray_indices = batch["indices"] - ray_bundle = self.eval_ray_generator(ray_indices) + ret = next(self.iter_eval_raybundles) + assert len(ret) == 1, f"batch size should be one {len(ret)}" + ray_bundle, batch = ret[0] + ray_bundle = ray_bundle.to(self.device) return ray_bundle, batch def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: - for camera, batch in self.eval_dataloader: + for camera, batch in self.image_eval_dataloader: assert camera.shape[0] == 1 return camera, batch raise ValueError("No more eval images") diff --git a/nerfstudio/data/datamanagers/datamanager_configs.py b/nerfstudio/data/datamanagers/datamanager_configs.py new file mode 100644 index 0000000000..960ce8f8f0 --- /dev/null +++ b/nerfstudio/data/datamanagers/datamanager_configs.py @@ -0,0 +1,34 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration classes for our datamanagers. +""" + +@dataclass +class DataManagerConfig(InstantiateConfig): + """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers; + After instantiation, data manager holds both train/eval datasets and is in charge of returning unpacked + train/eval data at each iteration + """ + + _target: Type = field(default_factory=lambda: DataManager) + """Target class to instantiate.""" + data: Optional[Path] = None + """Source of data, may not be used by all models.""" + masks_on_gpu: bool = False + """Process masks on GPU for speed at the expense of memory, if True.""" + images_on_gpu: bool = False + """Process images on GPU for speed at the expense of memory, if True.""" + diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index b28e530f91..87c8f9f87a 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -41,7 +41,7 @@ from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader, RandIndicesEvalDataloader #,CacheDataloader from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 427792ca86..30ad214a89 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -21,7 +21,8 @@ import multiprocessing import random from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Sized, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sized, Tuple, Union, cast +from dataclasses import field import torch from rich.progress import track @@ -34,93 +35,277 @@ from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate from nerfstudio.utils.misc import get_dict_to_torch from nerfstudio.utils.rich_utils import CONSOLE +from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig +from nerfstudio.model_components.ray_generators import RayGenerator -class CacheDataloader(DataLoader): - """Collated image dataset that implements caching of default-pytorch-collatable data. - Creates batches of the InputDataset return type. - +def variable_res_collate(batch: List[Dict]) -> Dict: + """Default collate function for the cached dataloader. Args: - dataset: Dataset to sample from. - num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. - num_times_to_repeat_images: How often to yield an image batch before resampling. -1 to never pick new images. - device: Device to perform computation. - collate_fn: The function we will use to collate our training data + batch: Batch of samples from the dataset. + Returns: + Collated batch. """ + images = [] + imgdata_lists = defaultdict(list) + for data in batch: + image = data.pop("image") + images.append(image) + topop = [] + for key, val in data.items(): + if isinstance(val, torch.Tensor): + # if the value has same height and width as the image, assume that it should be collated accordingly. + if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]: + imgdata_lists[key].append(val) + topop.append(key) + # now that iteration is complete, the image data items can be removed from the batch + for key in topop: + del data[key] + new_batch = nerfstudio_collate(batch) + new_batch["image"] = images + new_batch.update(imgdata_lists) + + return new_batch + + +# class CacheDataloader(DataLoader): +# """Collated image dataset that implements caching of default-pytorch-collatable data. +# Creates batches of the InputDataset return type. + +# Args: +# dataset: Dataset to sample from. +# num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. +# num_times_to_repeat_images: How often to yield an image batch before resampling. -1 to never pick new images. +# device: Device to perform computation. +# collate_fn: The function we will use to collate our training data +# """ + +# def __init__( +# self, +# dataset: Dataset, +# num_images_to_sample_from: int = -1, +# num_times_to_repeat_images: int = -1, +# device: Union[torch.device, str] = "cpu", +# collate_fn: Callable[[Any], Any] = nerfstudio_collate, +# exclude_batch_keys_from_device: Optional[List[str]] = None, +# **kwargs, +# ): +# if exclude_batch_keys_from_device is None: +# exclude_batch_keys_from_device = ["image"] +# self.dataset = dataset +# assert isinstance(self.dataset, Sized) + +# super().__init__(dataset=dataset, **kwargs) # This will set self.dataset +# self.num_times_to_repeat_images = num_times_to_repeat_images +# self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) +# self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from +# self.device = device +# self.collate_fn = collate_fn +# self.num_workers = kwargs.get("num_workers", 0) +# self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + +# self.num_repeated = self.num_times_to_repeat_images # starting value +# self.first_time = True + +# self.cached_collated_batch = None +# if self.cache_all_images: +# CONSOLE.print(f"Caching all {len(self.dataset)} images.") +# if len(self.dataset) > 500: +# CONSOLE.print( +# "[bold yellow]Warning: If you run out of memory, try reducing the number of images to sample from." +# ) +# self.cached_collated_batch = self._get_collated_batch() +# elif self.num_times_to_repeat_images == -1: +# CONSOLE.print( +# f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, without resampling." +# ) +# else: +# CONSOLE.print( +# f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, " +# f"resampling every {self.num_times_to_repeat_images} iters." +# ) + +# def __getitem__(self, idx): +# return self.dataset.__getitem__(idx) + +# def _get_batch_list(self): +# """Returns a list of batches from the dataset attribute.""" + +# assert isinstance(self.dataset, Sized) +# indices = random.sample(range(len(self.dataset)), k=self.num_images_to_sample_from) +# batch_list = [] +# results = [] + +# num_threads = int(self.num_workers) * 4 +# num_threads = min(num_threads, multiprocessing.cpu_count() - 1) +# num_threads = max(num_threads, 1) + +# with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: +# for idx in indices: +# res = executor.submit(self.dataset.__getitem__, idx) +# results.append(res) + +# for res in track(results, description="Loading data batch", transient=True): +# batch_list.append(res.result()) + +# return batch_list + +# def _get_collated_batch(self): +# """Returns a collated batch of images.""" +# batch_list = self._get_batch_list() +# collated_batch = self.collate_fn(batch_list) +# collated_batch = get_dict_to_torch( +# collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device +# ) +# return collated_batch + +# def __iter__(self): +# while True: +# if self.cache_all_images: +# collated_batch = self.cached_collated_batch +# elif self.first_time or ( +# self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images +# ): +# # trigger a reset +# self.num_repeated = 0 +# collated_batch = self._get_collated_batch() +# # possibly save a cached item +# self.cached_collated_batch = collated_batch if self.num_times_to_repeat_images != 0 else None +# self.first_time = False +# else: +# collated_batch = self.cached_collated_batch +# self.num_repeated += 1 +# yield collated_batch + +import concurrent.futures +import math +import multiprocessing +import random +from typing import Sized +from torch.utils.data import Dataset +from nerfstudio.utils.misc import get_dict_to_torch +from tqdm.auto import tqdm +class RayBatchStream(torch.utils.data.IterableDataset): + """Wrapper around Pytorch's IterableDataset to generate the next batch of rays (next RayBundle) and corresponding labels + with multiple parallel workers. + + Each worker samples a small batch of images, pixel samples those images, and generates rays for one training step. + The same batch of images can be pixel sampled multiple times hasten ray generation, as retrieving images is process + bottlenecked by disk read speed. To avoid Out-Of-Memory (OOM) errors, this batch of images is small and regenerated + by resampling the worker's partition of images to maintain sampling diversity. + """ def __init__( self, - dataset: Dataset, + input_dataset: Dataset, + num_rays_per_batch: int = 1024, num_images_to_sample_from: int = -1, num_times_to_repeat_images: int = -1, device: Union[torch.device, str] = "cpu", - collate_fn: Callable[[Any], Any] = nerfstudio_collate, + # variable_res_collate avoids np.stack'ing images, which allows it to be much faster than `nerfstudio_collate` + collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(variable_res_collate)), + num_image_load_threads: int = 4, exclude_batch_keys_from_device: Optional[List[str]] = None, - **kwargs, + load_from_disk: bool = False, + patch_size: int = 1, + ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] - self.dataset = dataset - assert isinstance(self.dataset, Sized) - - super().__init__(dataset=dataset, **kwargs) # This will set self.dataset + self.input_dataset = input_dataset + assert isinstance(self.input_dataset, Sized) + self.num_rays_per_batch = num_rays_per_batch + """Number of rays per batch to user per training iteration.""" + self.num_images_to_sample_from = num_images_to_sample_from + """How many images to sample to generate a RayBundle. More images means greater sampling diversity at expense of increased RAM usage.""" self.num_times_to_repeat_images = num_times_to_repeat_images - self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) - self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from + """How many RayBundles to generate from this batch of images after sampling `num_images_to_sample_from` images.""" self.device = device - self.collate_fn = collate_fn - self.num_workers = kwargs.get("num_workers", 0) + """If a CUDA GPU is present, self.device will be set to use that GPU.""" + self.collate_fn = collate_fn + """What collate function is used to batch images to be used for pixel sampling and ray generation. """ + self.num_image_load_threads = num_image_load_threads + """Number of threads created to read images from disk and form collated batches.""" self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - - self.num_repeated = self.num_times_to_repeat_images # starting value - self.first_time = True - - self.cached_collated_batch = None - if self.cache_all_images: - CONSOLE.print(f"Caching all {len(self.dataset)} images.") - if len(self.dataset) > 500: - CONSOLE.print( - "[bold yellow]Warning: If you run out of memory, try reducing the number of images to sample from." - ) - self.cached_collated_batch = self._get_collated_batch() - elif self.num_times_to_repeat_images == -1: - CONSOLE.print( - f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, without resampling." - ) - else: - CONSOLE.print( - f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, " - f"resampling every {self.num_times_to_repeat_images} iters." + """Which key of the batch (such as 'image', 'mask','depth') to prevent from moving to the device. + For instance, if you would like to conserve GPU memory, don't move the image tensors to the GPU, + which comes at a cost of total training time. The default value is ['image'].""" + self.load_from_disk = load_from_disk + self.patch_size = patch_size + """Size of patch to sample from. If > 1, patch-based sampling will be used.""" + self.enable_per_worker_image_caching = load_from_disk == False + """If True, each worker's will cache its entire partition of the image dataset as image tensors in RAM.""" + self._cached_collated_batch = None + """Each worker has a self._cached_collated_batch contains a collated batch of images cached in RAM for a specific worker that's ready for pixel sampling.""" + self.pixel_sampler_config: PixelSamplerConfig = PixelSamplerConfig() + """Specifies the pixel sampler config used to sample pixels from images. Each worker will have its own pixel sampler""" + self.ray_generator: RayGenerator = None + """Each worker will have its own ray generator, so this is set to None for now.""" + + def _get_pixel_sampler(self, dataset: Dataset, num_rays_per_batch: int) -> PixelSampler: + """copied from VanillaDataManager.""" + from nerfstudio.cameras.cameras import CameraType + + if self.patch_size > 1 and type(self.pixel_sampler_config) is PixelSamplerConfig: + return PatchPixelSamplerConfig().setup( + patch_size=self.patch_size, num_rays_per_batch=num_rays_per_batch ) + is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + fisheye_crop_radius = None + if dataset.cameras.metadata is not None: + fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + return self.pixel_sampler_config.setup( + is_equirectangular=is_equirectangular, + num_rays_per_batch=num_rays_per_batch, + fisheye_crop_radius=fisheye_crop_radius, + ) - def __getitem__(self, idx): - return self.dataset.__getitem__(idx) - - def _get_batch_list(self): - """Returns a list of batches from the dataset attribute.""" - - assert isinstance(self.dataset, Sized) - indices = random.sample(range(len(self.dataset)), k=self.num_images_to_sample_from) + def _get_batch_list(self, indices=None): + """Returns a list representing a single batch from the dataset attribute. + Each item of the list is a dictionary with dict_keys(['image_idx', 'image']) representing 1 image. + This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch + The length of the list is equal to the (# of training images) / (num_workers)""" + + assert isinstance(self.input_dataset, Sized) + if indices is None: + # Note: self.num_images_to_sample_from is usually -1, but _get_batch_list is usually called with indices != None. + # _get_batch_list is used by _get_collated_batch, whose indices = some partition of the dataset + indices = random.sample(range(len(self.input_dataset)), k=self.num_images_to_sample_from) batch_list = [] results = [] - num_threads = int(self.num_workers) * 4 + num_threads = ( + int(self.num_image_load_threads) + if not self.enable_per_worker_image_caching + else 4 * int(self.num_image_load_threads) + ) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) num_threads = max(num_threads, 1) + # NB: this is I/O heavy because we are going to disk and reading an image filename + # hence multi-threaded inside the worker with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: for idx in indices: - res = executor.submit(self.dataset.__getitem__, idx) + res = executor.submit(self.input_dataset.__getitem__, idx) results.append(res) - - for res in track(results, description="Loading data batch", transient=True): + for res in results: batch_list.append(res.result()) - + return batch_list - def _get_collated_batch(self): - """Returns a collated batch.""" - batch_list = self._get_batch_list() + def _get_collated_batch(self, indices=None): + """Takes the output of _get_batch_list and collates them with nerfstudio_collate() or variable_res_collate() + Note: dict is an instance of collections.abc.Mapping + + The resulting output is collated_batch: a dictionary with dict_keys(['image_idx', 'image']) + collated_batch['image_idx'] is tensor with shape torch.Size([per_worker]) + collated_batch['image'] is tensor with shape torch.Size([per_worker, height, width, 3]) + """ + batch_list = self._get_batch_list(indices=indices) collated_batch = self.collate_fn(batch_list) collated_batch = get_dict_to_torch( collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device @@ -128,22 +313,55 @@ def _get_collated_batch(self): return collated_batch def __iter__(self): + """This implementation allows every worker only cache the indices of the images they will use to generate rays to conserve RAM memory.""" + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: # if we have multiple processes + per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers))) + slice_start = worker_info.id * per_worker + else: # we only have a single process + per_worker = len(self.input_dataset) + slice_start = 0 + dataset_indices = list( + range(len(self.input_dataset)) + ) + worker_indices = dataset_indices[ + slice_start : slice_start + per_worker + ] # the indices of the datapoints in the dataset this worker will load + if self.enable_per_worker_image_caching: + self._cached_collated_batch = self._get_collated_batch(worker_indices) + r = random.Random(3301) + num_rays_per_loop = self.num_rays_per_batch # default train_num_rays_per_batch is 4096 + # each worker has its own pixel sampler + worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) + self.ray_generator = RayGenerator(self.input_dataset.cameras) + + i = 0 while True: - if self.cache_all_images: - collated_batch = self.cached_collated_batch - elif self.first_time or ( - self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images - ): - # trigger a reset - self.num_repeated = 0 - collated_batch = self._get_collated_batch() - # possibly save a cached item - self.cached_collated_batch = collated_batch if self.num_times_to_repeat_images != 0 else None - self.first_time = False - else: - collated_batch = self.cached_collated_batch - self.num_repeated += 1 - yield collated_batch + if self.enable_per_worker_image_caching: + collated_batch = self._cached_collated_batch + elif i % self.num_times_to_repeat_images == 0: + r.shuffle(worker_indices) + + if self.num_images_to_sample_from == -1: # if -1, the worker gets all available indices in its partition + image_indices = worker_indices + else: # get a total of 'num_images_to_sample_from' image indices + image_indices = worker_indices[:self.num_images_to_sample_from] + + collated_batch = self._get_collated_batch(image_indices) + i += 1 + """ + Here, the variable 'batch' refers to the output of our pixel sampler. + - batch is a dict_keys(['image', 'indices']) + - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’] + - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol) + What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, + and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) + """ + batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + # collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image' + ray_indices = batch["indices"] + ray_bundle = self.ray_generator(ray_indices).to(self.device) # the ray_bundle is on the GPU; batch["image"] is on the CPU + yield ray_bundle, batch class EvalDataloader(DataLoader): From 3dc20316e2b0a635d1c1ede081ec14af32dacc57 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 14 Sep 2024 18:07:29 -0700 Subject: [PATCH 51/78] fixing base_piplines, deleting a weird datamanager_configs file that was accidently created --- .../data/datamanagers/datamanager_configs.py | 34 ------------------- nerfstudio/pipelines/base_pipeline.py | 1 - 2 files changed, 35 deletions(-) delete mode 100644 nerfstudio/data/datamanagers/datamanager_configs.py diff --git a/nerfstudio/data/datamanagers/datamanager_configs.py b/nerfstudio/data/datamanagers/datamanager_configs.py deleted file mode 100644 index 960ce8f8f0..0000000000 --- a/nerfstudio/data/datamanagers/datamanager_configs.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Configuration classes for our datamanagers. -""" - -@dataclass -class DataManagerConfig(InstantiateConfig): - """Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers; - After instantiation, data manager holds both train/eval datasets and is in charge of returning unpacked - train/eval data at each iteration - """ - - _target: Type = field(default_factory=lambda: DataManager) - """Target class to instantiate.""" - data: Optional[Path] = None - """Source of data, may not be used by all models.""" - masks_on_gpu: bool = False - """Process masks on GPU for speed at the expense of memory, if True.""" - images_on_gpu: bool = False - """Process images on GPU for speed at the expense of memory, if True.""" - diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index f29b38a4e0..731f214e77 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -298,7 +298,6 @@ def get_train_loss_dict(self, step: int): step: current iteration step to update sampler if using DDP (distributed) """ ray_bundle, batch = self.datamanager.next_train(step) - ray_bundle = ray_bundle.to(self.device) model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 metrics_dict = self.model.get_metrics_dict(model_outputs, batch) loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) From 89f3d98cca4af726ca327b4ca56faa95b2f40e5a Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sat, 14 Sep 2024 18:33:18 -0700 Subject: [PATCH 52/78] cleaning up next_train --- nerfstudio/data/datamanagers/base_datamanager.py | 12 +++--------- nerfstudio/data/utils/dataloaders.py | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 3a91cd3749..0d3c7e8ed7 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -546,19 +546,13 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" self.train_count += 1 - ret = next(self.iter_train_raybundles) - assert len(ret) == 1, f"batch size should be one" - ray_bundle, batch = ret[0] - ray_bundle = ray_bundle.to(self.device) + ray_bundle, batch = next(self.iter_train_raybundles)[0] return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: - """Returns the next batch of data from the train dataloader.""" + """Returns the next batch of data from the eval dataloader.""" self.eval_count += 1 - ret = next(self.iter_eval_raybundles) - assert len(ret) == 1, f"batch size should be one {len(ret)}" - ray_bundle, batch = ret[0] - ray_bundle = ray_bundle.to(self.device) + ray_bundle, batch = next(self.iter_train_raybundles)[0] return ray_bundle, batch def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 30ad214a89..0bd104cc03 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -40,7 +40,7 @@ def variable_res_collate(batch: List[Dict]) -> Dict: - """Default collate function for the cached dataloader. + """Default collate function for our dataloader. Args: batch: Batch of samples from the dataset. Returns: From 14e60e550f64a7911c6a3bd4a744799099303c0d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 19 Sep 2024 14:54:48 -0700 Subject: [PATCH 53/78] replaced parallel datamanager with new datamanager --- .../data/datamanagers/parallel_datamanager.py | 251 +++++------------- 1 file changed, 73 insertions(+), 178 deletions(-) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 87c8f9f87a..f6c459e0b8 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -47,97 +47,29 @@ from nerfstudio.utils.rich_utils import CONSOLE -@dataclass -class ParallelDataManagerConfig(VanillaDataManagerConfig): - """Config for a `ParallelDataManager` which reads data in multiple processes""" - - _target: Type = field(default_factory=lambda: ParallelDataManager) - """Target class to instantiate.""" - num_processes: int = 1 - """Number of processes to use for train data loading. More than 1 doesn't result in that much better performance""" - queue_size: int = 2 - """Size of shared data queue containing generated ray bundles and batches. - If queue_size <= 0, the queue size is infinite.""" - max_thread_workers: Optional[int] = None - """Maximum number of threads to use in thread pool executor. If None, use ThreadPool default.""" - - -class DataProcessor(mp.Process): # type: ignore - """Parallel dataset batch processor. - - This class is responsible for generating ray bundles from an input dataset - in parallel python processes. - - Args: - out_queue: the output queue for storing the processed data - config: configuration object for the parallel data manager - dataparser_outputs: outputs from the dataparser - dataset: input dataset - pixel_sampler: The pixel sampler for sampling rays - """ - - def __init__( - self, - out_queue: mp.Queue, # type: ignore - config: ParallelDataManagerConfig, - dataparser_outputs: DataparserOutputs, - dataset: TDataset, - pixel_sampler: PixelSampler, - ): - super().__init__() - self.daemon = True - self.out_queue = out_queue - self.config = config - self.dataparser_outputs = dataparser_outputs - self.dataset = dataset - self.exclude_batch_keys_from_device = self.dataset.exclude_batch_keys_from_device - self.pixel_sampler = pixel_sampler - self.ray_generator = RayGenerator(self.dataset.cameras) - - def run(self): - """Append out queue in parallel with ray bundles and batches.""" - self.cache_images() - while True: - batch = self.pixel_sampler.sample(self.img_data) - ray_indices = batch["indices"] - ray_bundle: RayBundle = self.ray_generator(ray_indices) - # check that GPUs are available - if torch.cuda.is_available(): - ray_bundle = ray_bundle.pin_memory() - while True: - try: - self.out_queue.put((ray_bundle, batch)) - break - except queue.Full: - time.sleep(0.0001) - except Exception: - CONSOLE.print_exception() - CONSOLE.print("[bold red]Error: Error occurred in parallel datamanager queue.") - - def cache_images(self): - """Caches all input images into a NxHxWx3 tensor.""" - indices = range(len(self.dataset)) - batch_list = [] - results = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=self.config.max_thread_workers) as executor: - for idx in indices: - res = executor.submit(self.dataset.__getitem__, idx) - results.append(res) - for res in track(results, description="Loading data batch", transient=False): - batch_list.append(res.result()) - self.img_data = self.config.collate_fn(batch_list) - - class ParallelDataManager(DataManager, Generic[TDataset]): - """Data manager implementation for parallel dataloading. + """Basic stored data manager implementation. + + This is pretty much a port over from our old dataloading utilities, and is a little jank + under the hood. We may clean this up a little bit under the hood with more standard dataloading + components that can be strung together, but it can be just used as a black box for now since + only the constructor is likely to change in the future, or maybe passing in step number to the + next_train and next_eval functions. Args: config: the DataManagerConfig used to instantiate class """ + config: VanillaDataManagerConfig + train_dataset: TDataset + eval_dataset: TDataset + train_dataparser_outputs: DataparserOutputs + train_pixel_sampler: Optional[PixelSampler] = None + eval_pixel_sampler: Optional[PixelSampler] = None + def __init__( self, - config: ParallelDataManagerConfig, + config: VanillaDataManagerConfig, device: Union[torch.device, str] = "cpu", test_mode: Literal["test", "val", "inference"] = "val", world_size: int = 1, @@ -148,6 +80,7 @@ def __init__( self.device = device self.world_size = world_size self.local_rank = local_rank + self.sampler = None self.test_mode = test_mode self.test_split = "test" if test_mode in ["test", "inference"] else "val" self.dataparser_config = self.config.dataparser @@ -160,36 +93,38 @@ def __init__( self.dataparser.downscale_factor = 1 # Avoid opening images self.includes_time = self.dataparser.includes_time self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") - self.eval_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split=self.test_split) - cameras = self.train_dataparser_outputs.cameras - if len(cameras) > 1: - for i in range(1, len(cameras)): - if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height: - CONSOLE.print("Variable resolution, using variable_res_collate") - self.config.collate_fn = variable_res_collate - break + self.train_dataset = self.create_train_dataset() self.eval_dataset = self.create_eval_dataset() self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device - # Spawn is critical for not freezing the program (PyTorch compatability issue) - # check if spawn is already set - if mp.get_start_method(allow_none=True) is None: # type: ignore - mp.set_start_method("spawn") # type: ignore + if self.config.masks_on_gpu is True and "mask" in self.exclude_batch_keys_from_device: + self.exclude_batch_keys_from_device.remove("mask") + if self.config.images_on_gpu is True and "image" in self.exclude_batch_keys_from_device: + self.exclude_batch_keys_from_device.remove("image") + + if self.train_dataparser_outputs is not None: + cameras = self.train_dataparser_outputs.cameras + if len(cameras) > 1: + for i in range(1, len(cameras)): + if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height or True: # or True: # ADJUST COLLATE FN HERE + CONSOLE.print("Variable resolution, using variable_res_collate") + self.config.collate_fn = variable_res_collate + break super().__init__() @cached_property def dataset_type(self) -> Type[TDataset]: """Returns the dataset type passed as the generic argument""" default: Type[TDataset] = cast(TDataset, TDataset.__default__) # type: ignore - orig_class: Type[ParallelDataManager] = get_orig_class(self, default=None) # type: ignore - if type(self) is ParallelDataManager and orig_class is None: + orig_class: Type[VanillaDataManager] = get_orig_class(self, default=None) # type: ignore + if type(self) is VanillaDataManager and orig_class is None: return default - if orig_class is not None and get_origin(orig_class) is ParallelDataManager: + if orig_class is not None and get_origin(orig_class) is VanillaDataManager: return get_args(orig_class)[0] # For inherited classes, we need to find the correct type to instantiate for base in getattr(self, "__orig_bases__", []): - if get_origin(base) is ParallelDataManager: + if get_origin(base) is VanillaDataManager: for value in get_args(base): if isinstance(value, ForwardRef): if value.__forward_evaluated__: @@ -203,126 +138,93 @@ def dataset_type(self) -> Type[TDataset]: return default def create_train_dataset(self) -> TDataset: - """Sets up the data loaders for training.""" + """Sets up the data loaders for training""" return self.dataset_type( dataparser_outputs=self.train_dataparser_outputs, scale_factor=self.config.camera_res_scale_factor, ) def create_eval_dataset(self) -> TDataset: - """Sets up the data loaders for evaluation.""" + """Sets up the data loaders for evaluation""" return self.dataset_type( dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), scale_factor=self.config.camera_res_scale_factor, ) - def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: - """Infer pixel sampler to use.""" - if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch - ) - is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() - if is_equirectangular.any(): - CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") - - fisheye_crop_radius = None - if dataset.cameras.metadata is not None: - fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") - - return self.config.pixel_sampler.setup( - is_equirectangular=is_equirectangular, - num_rays_per_batch=num_rays_per_batch, - fisheye_crop_radius=fisheye_crop_radius, - ) def setup_train(self): - """Sets up parallel python data processes for training.""" - assert self.train_dataset is not None - self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) # type: ignore - self.data_queue = mp.Queue(maxsize=self.config.queue_size) # type: ignore - self.data_procs = [ - DataProcessor( - out_queue=self.data_queue, # type: ignore - config=self.config, - dataparser_outputs=self.train_dataparser_outputs, - dataset=self.train_dataset, - pixel_sampler=self.train_pixel_sampler, - ) - for i in range(self.config.num_processes) - ] - for proc in self.data_procs: - proc.start() - print("Started threads") - - def setup_eval(self): - """Sets up the data loader for evaluation.""" - assert self.eval_dataset is not None - CONSOLE.print("Setting up evaluation dataset...") - self.eval_image_dataloader = CacheDataloader( - self.eval_dataset, - num_images_to_sample_from=self.config.eval_num_images_to_sample_from, - num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images, + self.train_raybatchstream = RayBatchStream( + input_dataset=self.train_dataset, + num_rays_per_batch=self.config.train_num_rays_per_batch, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - num_workers=self.world_size * 4, - pin_memory=True, - collate_fn=self.config.collate_fn, - exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, + collate_fn = variable_res_collate, ) - self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) - self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch) # type: ignore - self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device)) - # for loading full images - self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader( + self.train_ray_dataloader = torch.utils.data.DataLoader( + self.train_raybatchstream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + prefetch_factor=self.config.prefetch_factor, + shuffle=False, + collate_fn=identity_collate, # Our dataset handles batching / collation of rays + ) + self.iter_train_raybundles = iter(self.train_ray_dataloader) + + def setup_eval(self): + self.eval_raybatchstream = RayBatchStream( input_dataset=self.eval_dataset, + num_rays_per_batch=self.config.train_num_rays_per_batch, + num_images_to_sample_from=self.config.train_num_images_to_sample_from, + num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - num_workers=self.world_size * 4, + collate_fn = variable_res_collate, + load_from_disk = True, + ) + self.eval_ray_dataloader = torch.utils.data.DataLoader( + self.eval_raybatchstream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + prefetch_factor=self.config.prefetch_factor, + shuffle=False, + collate_fn=identity_collate, # Our dataset handles batching / collation of rays ) - self.eval_dataloader = RandIndicesEvalDataloader( + self.iter_eval_raybundles = iter(self.eval_ray_dataloader) + self.image_eval_dataloader = RandIndicesEvalDataloader( input_dataset=self.eval_dataset, device=self.device, num_workers=self.world_size * 4, ) def next_train(self, step: int) -> Tuple[RayBundle, Dict]: - """Returns the next batch of data from the parallel training processes.""" + """Returns the next batch of data from the train dataloader.""" self.train_count += 1 - bundle, batch = self.data_queue.get() - ray_bundle = bundle.to(self.device) + ray_bundle, batch = next(self.iter_train_raybundles)[0] return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the eval dataloader.""" self.eval_count += 1 - image_batch = next(self.iter_eval_image_dataloader) - assert self.eval_pixel_sampler is not None - assert isinstance(image_batch, dict) - batch = self.eval_pixel_sampler.sample(image_batch) - ray_indices = batch["indices"] - ray_bundle = self.eval_ray_generator(ray_indices) + ray_bundle, batch = next(self.iter_train_raybundles)[0] return ray_bundle, batch def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: - """Retrieve the next eval image.""" - for camera, batch in self.eval_dataloader: + for camera, batch in self.image_eval_dataloader: assert camera.shape[0] == 1 return camera, batch raise ValueError("No more eval images") def get_train_rays_per_batch(self) -> int: - """Returns the number of rays per batch for training.""" if self.train_pixel_sampler is not None: return self.train_pixel_sampler.num_rays_per_batch return self.config.train_num_rays_per_batch def get_eval_rays_per_batch(self) -> int: - """Returns the number of rays per batch for evaluation.""" if self.eval_pixel_sampler is not None: return self.eval_pixel_sampler.num_rays_per_batch return self.config.eval_num_rays_per_batch def get_datapath(self) -> Path: - """Returns the path to the data. This is used to determine where to save camera paths.""" return self.config.dataparser.data def get_param_groups(self) -> Dict[str, List[Parameter]]: @@ -331,10 +233,3 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: A list of dictionaries containing the data manager's param groups. """ return {} - - def __del__(self): - """Clean up the parallel data processes.""" - if hasattr(self, "data_procs"): - for proc in self.data_procs: - proc.terminate() - proc.join() From 204dfb2046d810888ca0853708823e3f542d94ae Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 19 Sep 2024 15:06:03 -0700 Subject: [PATCH 54/78] reverted the original base_datamanager.py, new datamanager replaced parallel_datamanager.py --- .../data/datamanagers/base_datamanager.py | 150 +++++++++--------- .../data/datamanagers/parallel_datamanager.py | 2 +- 2 files changed, 73 insertions(+), 79 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 0d3c7e8ed7..cff03607b9 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -39,7 +39,7 @@ get_args, get_origin, ) -import time + import torch import tyro from torch import nn @@ -56,14 +56,8 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import ( - # CacheDataloader, - RayBatchStream, - FixedIndicesEvalDataloader, - RandIndicesEvalDataloader, -) +from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate -from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import IterableWrapper, get_orig_class @@ -92,6 +86,7 @@ def variable_res_collate(batch: List[Dict]) -> Dict: # now that iteration is complete, the image data items can be removed from the batch for key in topop: del data[key] + new_batch = nerfstudio_collate(batch) new_batch["image"] = images new_batch.update(imgdata_lists) @@ -183,7 +178,6 @@ def __init__(self): self.train_count = 0 self.eval_count = 0 if self.train_dataset and self.test_mode != "inference": - # print(self.setup_train) # prints self.setup_train() if self.eval_dataset and self.test_mode != "inference": self.setup_eval() @@ -317,8 +311,6 @@ class VanillaDataManagerConfig(DataManagerConfig): """Target class to instantiate.""" dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig) """Specifies the dataparser used to unpack the data.""" - cache_images_type: Literal["uint8", "float32"] = "float32" - """The image type returned from manager, caching images in uint8 saves memory""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" train_num_images_to_sample_from: int = -1 @@ -339,20 +331,10 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the collate function to use for the train and eval dataloaders.""" camera_res_scale_factor: float = 1.0 """The scale factor for scaling spatial data such as images, mask, semantics - along with relevant information about camera intrinsics""" + along with relevant information about camera intrinsics + """ patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - use_parallel_dataloader: bool = True - """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles.""" - load_from_disk: bool = False - """If True, conserves RAM memory by loading images from disk. - If False, caches all the images as tensors to RAM and loads from RAM.""" - dataloader_num_workers: int = 0 - """The number of workers performing the dataloading from either disk/RAM, which - includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = None - """The limit number of batches a worker will start loading once an iterator is created. - More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -369,23 +351,7 @@ def __post_init__(self): "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow" ) warnings.warn("above message coming from", FutureWarning, stacklevel=3) - - """ - These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted - Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck. - """ - if self.load_from_disk: - self.train_num_images_to_sample_from = 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from - self.train_num_times_to_repeat_images = 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images - self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None - - if self.use_parallel_dataloader: - try: - torch.multiprocessing.set_start_method("spawn") - except RuntimeError: - pass - self.dataloader_num_workers = 4 - + TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) @@ -449,7 +415,7 @@ def __init__( cameras = self.train_dataparser_outputs.cameras if len(cameras) > 1: for i in range(1, len(cameras)): - if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height or True: # or True: # ADJUST COLLATE FN HERE + if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height: CONSOLE.print("Variable resolution, using variable_res_collate") self.config.collate_fn = variable_res_collate break @@ -494,50 +460,68 @@ def create_eval_dataset(self) -> TDataset: scale_factor=self.config.camera_res_scale_factor, ) + def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: + """Infer pixel sampler to use.""" + if self.config.patch_size > 1 and type(self.config.pixel_sampler) is PixelSamplerConfig: + return PatchPixelSamplerConfig().setup( + patch_size=self.config.patch_size, num_rays_per_batch=num_rays_per_batch + ) + is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() + if is_equirectangular.any(): + CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") + + fisheye_crop_radius = None + if dataset.cameras.metadata is not None: + fisheye_crop_radius = dataset.cameras.metadata.get("fisheye_crop_radius") + + return self.config.pixel_sampler.setup( + is_equirectangular=is_equirectangular, + num_rays_per_batch=num_rays_per_batch, + fisheye_crop_radius=fisheye_crop_radius, + ) def setup_train(self): - self.train_raybatchstream = RayBatchStream( - input_dataset=self.train_dataset, - num_rays_per_batch=self.config.train_num_rays_per_batch, + """Sets up the data loaders for training""" + assert self.train_dataset is not None + CONSOLE.print("Setting up training dataset...") + self.train_image_dataloader = CacheDataloader( + self.train_dataset, num_images_to_sample_from=self.config.train_num_images_to_sample_from, num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - collate_fn = variable_res_collate, - load_from_disk = True, - ) - self.train_ray_dataloader = torch.utils.data.DataLoader( - self.train_raybatchstream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - prefetch_factor=self.config.prefetch_factor, - shuffle=False, - # pin_memory=True, - collate_fn=identity_collate, # Our dataset handles batching / collation of rays - pin_memory_device=self.device, + num_workers=self.world_size * 4, + pin_memory=True, + collate_fn=self.config.collate_fn, + exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, ) - self.iter_train_raybundles = iter(self.train_ray_dataloader) - + self.iter_train_image_dataloader = iter(self.train_image_dataloader) + self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) + self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) + def setup_eval(self): - self.eval_raybatchstream = RayBatchStream( - input_dataset=self.eval_dataset, - num_rays_per_batch=self.config.train_num_rays_per_batch, - num_images_to_sample_from=self.config.train_num_images_to_sample_from, - num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, + """Sets up the data loader for evaluation""" + assert self.eval_dataset is not None + CONSOLE.print("Setting up evaluation dataset...") + self.eval_image_dataloader = CacheDataloader( + self.eval_dataset, + num_images_to_sample_from=self.config.eval_num_images_to_sample_from, + num_times_to_repeat_images=self.config.eval_num_times_to_repeat_images, device=self.device, - collate_fn = variable_res_collate, - load_from_disk = True, + num_workers=self.world_size * 4, + pin_memory=True, + collate_fn=self.config.collate_fn, + exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, ) - self.eval_ray_dataloader = torch.utils.data.DataLoader( - self.eval_raybatchstream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - prefetch_factor=self.config.prefetch_factor, - shuffle=False, - collate_fn=identity_collate, # Our dataset handles batching / collation of rays - pin_memory_device=self.device, + self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) + self.eval_pixel_sampler = self._get_pixel_sampler(self.eval_dataset, self.config.eval_num_rays_per_batch) + self.eval_ray_generator = RayGenerator(self.eval_dataset.cameras.to(self.device)) + # for loading full images + self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader( + input_dataset=self.eval_dataset, + device=self.device, + num_workers=self.world_size * 4, ) - self.iter_eval_raybundles = iter(self.eval_ray_dataloader) - self.image_eval_dataloader = RandIndicesEvalDataloader( + self.eval_dataloader = RandIndicesEvalDataloader( input_dataset=self.eval_dataset, device=self.device, num_workers=self.world_size * 4, @@ -546,17 +530,27 @@ def setup_eval(self): def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" self.train_count += 1 - ray_bundle, batch = next(self.iter_train_raybundles)[0] + image_batch = next(self.iter_train_image_dataloader) + assert self.train_pixel_sampler is not None + assert isinstance(image_batch, dict) + batch = self.train_pixel_sampler.sample(image_batch) + ray_indices = batch["indices"] + ray_bundle = self.train_ray_generator(ray_indices) return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the eval dataloader.""" self.eval_count += 1 - ray_bundle, batch = next(self.iter_train_raybundles)[0] + image_batch = next(self.iter_eval_image_dataloader) + assert self.eval_pixel_sampler is not None + assert isinstance(image_batch, dict) + batch = self.eval_pixel_sampler.sample(image_batch) + ray_indices = batch["indices"] + ray_bundle = self.eval_ray_generator(ray_indices) return ray_bundle, batch def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: - for camera, batch in self.image_eval_dataloader: + for camera, batch in self.eval_dataloader: assert camera.shape[0] == 1 return camera, batch raise ValueError("No more eval images") diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index f6c459e0b8..0e0feeab49 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -41,7 +41,7 @@ from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader, RandIndicesEvalDataloader #,CacheDataloader +from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader, RandIndicesEvalDataloader from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE From 5864bc920847a8fa3633e4e8e951eef9ec3a12c2 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 19 Sep 2024 15:09:05 -0700 Subject: [PATCH 55/78] modified VanillaConfig, but VanillaDataManager is the same as before --- .../data/datamanagers/base_datamanager.py | 46 ++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index cff03607b9..db0d9ca332 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -39,7 +39,7 @@ get_args, get_origin, ) - +import time import torch import tyro from torch import nn @@ -56,8 +56,14 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.utils.dataloaders import ( + # CacheDataloader, + RayBatchStream, + FixedIndicesEvalDataloader, + RandIndicesEvalDataloader, +) from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate +from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import IterableWrapper, get_orig_class @@ -86,7 +92,6 @@ def variable_res_collate(batch: List[Dict]) -> Dict: # now that iteration is complete, the image data items can be removed from the batch for key in topop: del data[key] - new_batch = nerfstudio_collate(batch) new_batch["image"] = images new_batch.update(imgdata_lists) @@ -178,6 +183,7 @@ def __init__(self): self.train_count = 0 self.eval_count = 0 if self.train_dataset and self.test_mode != "inference": + # print(self.setup_train) # prints self.setup_train() if self.eval_dataset and self.test_mode != "inference": self.setup_eval() @@ -311,6 +317,8 @@ class VanillaDataManagerConfig(DataManagerConfig): """Target class to instantiate.""" dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig) """Specifies the dataparser used to unpack the data.""" + cache_images_type: Literal["uint8", "float32"] = "float32" + """The image type returned from manager, caching images in uint8 saves memory""" train_num_rays_per_batch: int = 1024 """Number of rays per batch to use per training iteration.""" train_num_images_to_sample_from: int = -1 @@ -331,10 +339,20 @@ class VanillaDataManagerConfig(DataManagerConfig): """Specifies the collate function to use for the train and eval dataloaders.""" camera_res_scale_factor: float = 1.0 """The scale factor for scaling spatial data such as images, mask, semantics - along with relevant information about camera intrinsics - """ + along with relevant information about camera intrinsics""" patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" + use_parallel_dataloader: bool = True + """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles.""" + load_from_disk: bool = False + """If True, conserves RAM memory by loading images from disk. + If False, caches all the images as tensors to RAM and loads from RAM.""" + dataloader_num_workers: int = 0 + """The number of workers performing the dataloading from either disk/RAM, which + includes collating, pixel sampling, unprojecting, ray generation etc.""" + prefetch_factor: int = None + """The limit number of batches a worker will start loading once an iterator is created. + More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -351,7 +369,23 @@ def __post_init__(self): "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow" ) warnings.warn("above message coming from", FutureWarning, stacklevel=3) - + + """ + These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted + Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck. + """ + if self.load_from_disk: + self.train_num_images_to_sample_from = 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from + self.train_num_times_to_repeat_images = 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images + self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None + + if self.use_parallel_dataloader: + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass + self.dataloader_num_workers = 4 + TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) From 6d97de311b5c1f2e8dea6a826c88c125f5db5759 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 19 Sep 2024 15:45:54 -0700 Subject: [PATCH 56/78] cleaning up, 2 datamanagers now - original and new parallel one --- nerfstudio/configs/method_configs.py | 2 ++ nerfstudio/data/datamanagers/base_datamanager.py | 3 +-- nerfstudio/data/datamanagers/parallel_datamanager.py | 11 ++--------- nerfstudio/data/utils/dataloaders.py | 3 ++- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index e9f98b274a..7b66971af3 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -92,10 +92,12 @@ mixed_precision=True, pipeline=VanillaPipelineConfig( datamanager=VanillaDataManagerConfig( + _target=ParallelDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=4096, load_from_disk=True, + use_parallel_dataloader=True, ), model=NerfactoModelConfig( eval_num_rays_per_chunk=1 << 15, diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index db0d9ca332..22e97399ba 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -183,7 +183,6 @@ def __init__(self): self.train_count = 0 self.eval_count = 0 if self.train_dataset and self.test_mode != "inference": - # print(self.setup_train) # prints self.setup_train() if self.eval_dataset and self.test_mode != "inference": self.setup_eval() @@ -384,7 +383,7 @@ def __post_init__(self): torch.multiprocessing.set_start_method("spawn") except RuntimeError: pass - self.dataloader_num_workers = 4 + self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 0e0feeab49..d34bcb3cbf 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -48,13 +48,7 @@ class ParallelDataManager(DataManager, Generic[TDataset]): - """Basic stored data manager implementation. - - This is pretty much a port over from our old dataloading utilities, and is a little jank - under the hood. We may clean this up a little bit under the hood with more standard dataloading - components that can be strung together, but it can be just used as a black box for now since - only the constructor is likely to change in the future, or maybe passing in step number to the - next_train and next_eval functions. + """Data manager implementation for parallel dataloading Args: config: the DataManagerConfig used to instantiate class @@ -80,7 +74,6 @@ def __init__( self.device = device self.world_size = world_size self.local_rank = local_rank - self.sampler = None self.test_mode = test_mode self.test_split = "test" if test_mode in ["test", "inference"] else "val" self.dataparser_config = self.config.dataparser @@ -106,7 +99,7 @@ def __init__( cameras = self.train_dataparser_outputs.cameras if len(cameras) > 1: for i in range(1, len(cameras)): - if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height or True: # or True: # ADJUST COLLATE FN HERE + if cameras[0].width != cameras[i].width or cameras[0].height != cameras[i].height or True: CONSOLE.print("Variable resolution, using variable_res_collate") self.config.collate_fn = variable_res_collate break diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 0bd104cc03..c35830e5cf 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -292,6 +292,7 @@ def _get_batch_list(self, indices=None): for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) results.append(res) + results = tqdm(results) # this is temporary and will be removed in the final push for res in results: batch_list.append(res.result()) @@ -333,7 +334,7 @@ def __iter__(self): num_rays_per_loop = self.num_rays_per_batch # default train_num_rays_per_batch is 4096 # each worker has its own pixel sampler worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) - self.ray_generator = RayGenerator(self.input_dataset.cameras) + self.ray_generator = RayGenerator(self.input_dataset.cameras) # the generated RayBundles will be on the same device as self.input_dataset.cameras (CPU) i = 0 while True: From 1f3401795d182831d141363549dc08fc45aa8e70 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Thu, 19 Sep 2024 16:05:41 -0700 Subject: [PATCH 57/78] able to train with new nerfstudio dataloader now --- nerfstudio/configs/method_configs.py | 10 +++++----- nerfstudio/data/datamanagers/parallel_datamanager.py | 5 ++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 7b66971af3..fb59e5c760 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -28,7 +28,7 @@ from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig -from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig +from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig @@ -92,7 +92,7 @@ mixed_precision=True, pipeline=VanillaPipelineConfig( datamanager=VanillaDataManagerConfig( - _target=ParallelDatamanager[InputDataset], + _target=ParallelDataManager[InputDataset], dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=4096, @@ -174,7 +174,7 @@ max_num_iterations=100000, mixed_precision=True, pipeline=VanillaPipelineConfig( - datamanager=ParallelDataManagerConfig( + datamanager=VanillaDataManagerConfig( dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=16384, eval_num_rays_per_batch=4096, @@ -308,7 +308,7 @@ method_configs["mipnerf"] = TrainerConfig( method_name="mipnerf", pipeline=VanillaPipelineConfig( - datamanager=ParallelDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024), + datamanager=VanillaDataManagerConfig(dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=1024), model=VanillaModelConfig( _target=MipNerfModel, loss_coefficients={"rgb_loss_coarse": 0.1, "rgb_loss_fine": 1.0}, @@ -381,7 +381,7 @@ max_num_iterations=30000, mixed_precision=False, pipeline=VanillaPipelineConfig( - datamanager=ParallelDataManagerConfig( + datamanager=VanillaDataManagerConfig( dataparser=BlenderDataParserConfig(), train_num_rays_per_batch=4096, eval_num_rays_per_batch=4096, diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index d34bcb3cbf..13e9da7ec4 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -36,12 +36,14 @@ DataManager, TDataset, VanillaDataManagerConfig, + VanillaDataManager, variable_res_collate, ) from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.utils.dataloaders import RayBatchStream, FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -153,6 +155,7 @@ def setup_train(self): num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, collate_fn = variable_res_collate, + load_from_disk = self.config.load_from_disk, ) self.train_ray_dataloader = torch.utils.data.DataLoader( self.train_raybatchstream, From 99cf86a9a4b837ebf4e0fe04920684035cbbd87d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 22 Sep 2024 18:28:25 -0700 Subject: [PATCH 58/78] side by side datamanagers, moved tons of logic into dataloaders.py and created new files for our parallel datamangers --- nerfstudio/configs/method_configs.py | 8 +- .../data/datamanagers/base_datamanager.py | 6 +- .../datamanagers/full_images_datamanager.py | 124 ++-- .../data/datamanagers/parallel_datamanager.py | 38 +- .../parallel_full_images_datamanager.py | 237 ++++++++ nerfstudio/data/datasets/base_dataset.py | 12 +- nerfstudio/data/utils/dataloaders.py | 549 +++++++++++++----- nerfstudio/pipelines/base_pipeline.py | 4 + 8 files changed, 752 insertions(+), 226 deletions(-) create mode 100644 nerfstudio/data/datamanagers/parallel_full_images_datamanager.py diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index fb59e5c760..426132f1e5 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -29,6 +29,7 @@ from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager +from nerfstudio.data.datamanagers.parallel_full_images_datamanager import ParallelFullImageDatamanager from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig @@ -605,11 +606,12 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( - # _target=ParallelFullImageDatamanager[InputDataset], - dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), - # dataparser=NerfstudioDataParserConfig(load_3D_points=True), + _target=ParallelFullImageDatamanager[InputDataset], + # dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), + dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", use_parallel_dataloader=True, + cache_images="disk", ), model=SplatfactoModelConfig(), ), diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 22e97399ba..330593e19d 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -341,7 +341,7 @@ class VanillaDataManagerConfig(DataManagerConfig): along with relevant information about camera intrinsics""" patch_size: int = 1 """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - use_parallel_dataloader: bool = True + use_parallel_dataloader: bool = False """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles.""" load_from_disk: bool = False """If True, conserves RAM memory by loading images from disk. @@ -352,6 +352,8 @@ class VanillaDataManagerConfig(DataManagerConfig): prefetch_factor: int = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" + cache_compressed_images: bool = False + """If True, cache raw image files as byte strings to RAM.""" # tyro.conf.Suppress prevents us from creating CLI arguments for this field. camera_optimizer: tyro.conf.Suppress[Optional[CameraOptimizerConfig]] = field(default=None) @@ -484,6 +486,7 @@ def create_train_dataset(self) -> TDataset: return self.dataset_type( dataparser_outputs=self.train_dataparser_outputs, scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, ) def create_eval_dataset(self) -> TDataset: @@ -491,6 +494,7 @@ def create_eval_dataset(self) -> TDataset: return self.dataset_type( dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, ) def _get_pixel_sampler(self, dataset: TDataset, num_rays_per_batch: int) -> PixelSampler: diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 751acfadbf..46d83e8311 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -75,12 +75,29 @@ class FullImageDatamanagerConfig(DataManagerConfig): samples from the pool of all training cameras without replacement before a new round of sampling starts.""" use_parallel_dataloader: bool = cache_images == "disk" """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - dataloader_num_workers: int = 4 + load_from_disk: bool = False + """If True, conserves RAM memory by loading images from disk. + If False, caches all the images as tensors to RAM and loads from RAM.""" + dataloader_num_workers: int = 0 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = 2 + prefetch_factor: int = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" + cache_compressed_images: bool = False + """If True, cache raw image files as byte strings to RAM.""" + + def __post_init__(self): + if self.load_from_disk: + self.prefetch_factor = 2 if self.use_parallel_dataloader else None + + if self.use_parallel_dataloader: + try: + torch.multiprocessing.set_start_method("spawn") + except RuntimeError: + pass + self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers + class FullImageDatamanager(DataManager, Generic[TDataset]): """ @@ -140,10 +157,6 @@ def __init__( self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" - if self.config.use_parallel_dataloader: - import torch.multiprocessing as mp - mp.set_start_method("spawn") - super().__init__() def sample_train_cameras(self): @@ -206,6 +219,7 @@ def _load_images( dataset = self.eval_dataset else: assert_never(split) + def undistort_idx(idx: int) -> Dict[str, torch.Tensor]: data = dataset.get_data(idx, image_type=self.config.cache_images_type) camera = dataset.cameras[idx].reshape(()) @@ -321,10 +335,10 @@ def setup_train(self): batch_size=1, num_workers=self.config.dataloader_num_workers, collate_fn=identity_collate, - # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? ) self.iter_train_image_dataloader = iter(self.train_image_dataloader) - + def setup_eval(self): """Sets up the data loader for evaluation""" if self.config.use_parallel_dataloader: @@ -386,7 +400,7 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: data = self.cached_train[image_idx] data["image"] = data["image"].to(self.device) - assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension" + assert lDuden(self.train_cameras.shape) == 1, "Assumes single batch dimension" camera = self.train_cameras[image_idx : image_idx + 1].to(self.device) if camera.metadata is None: camera.metadata = {} @@ -430,7 +444,7 @@ def _undistort_image( "We don't support the 4th Brown parameter for image undistortion, " "Only k1, k2, k3, p1, p2 can be non-zero." ) - # we rearrange the distortion parameters because OpenCV expects the order (k1, k2, p1, p2, k3) + # we rearrange the distortion parameters because OpenCV expects the order (k1, k2, p1, p2, k3) # see https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html distortion_params = np.array( [ @@ -599,41 +613,43 @@ def _undistort_image( return K, image, mask -def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: - """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics - Note: this method does not modify the dataset's attributes at all. - - Returns: The undistorted data (image, depth, mask, etc.) and the new linear Camera object - """ - data = dataset.get_data(idx, image_type) - camera = dataset.cameras[idx].reshape(()) - assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( - f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' - f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' - ) - if camera.distortion_params is None or torch.all(camera.distortion_params == 0): - return data - K = camera.get_intrinsics_matrices().numpy() - distortion_params = camera.distortion_params.numpy() - image = data["image"].numpy() - K, image, mask = _undistort_image(camera, distortion_params, data, image, K) - data["image"] = torch.from_numpy(image) - if mask is not None: - data["mask"] = mask - - # create a new Camera with the rectified / undistorted intrinsics - new_camera = Cameras( - camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), - fx=torch.Tensor([[float(K[0, 0])]]), - fy=torch.Tensor([[float(K[1, 1])]]), - cx=torch.Tensor([[float(K[0, 2])]]), - cy=torch.Tensor([[float(K[1, 2])]]), - width=torch.Tensor([[image.shape[1]]]).to(torch.int32), - height=torch.Tensor([[image.shape[0]]]).to(torch.int32), - ) - return data, new_camera +# def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: +# """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics +# Note: this method does not modify the dataset's attributes at all. + +# Returns: The undistorted data (image, depth, mask, etc.) and the new linear Camera object +# """ +# data = dataset.get_data(idx, image_type) +# camera = dataset.cameras[idx].reshape(()) +# assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( +# f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' +# f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' +# ) +# if camera.distortion_params is None or torch.all(camera.distortion_params == 0): +# return data +# K = camera.get_intrinsics_matrices().numpy() +# distortion_params = camera.distortion_params.numpy() +# image = data["image"].numpy() +# K, image, mask = _undistort_image(camera, distortion_params, data, image, K) +# data["image"] = torch.from_numpy(image) +# if mask is not None: +# data["mask"] = mask + +# # create a new Camera with the rectified / undistorted intrinsics +# new_camera = Cameras( +# camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), +# fx=torch.Tensor([[float(K[0, 0])]]), +# fy=torch.Tensor([[float(K[1, 1])]]), +# cx=torch.Tensor([[float(K[0, 2])]]), +# cy=torch.Tensor([[float(K[1, 2])]]), +# width=torch.Tensor([[image.shape[1]]]).to(torch.int32), +# height=torch.Tensor([[image.shape[0]]]).to(torch.int32), +# ) +# return data, new_camera import math + + class ImageBatchStream(torch.utils.data.IterableDataset): """ A wrapper of InputDataset that outputs undistorted full images and cameras. This makes the @@ -653,14 +669,12 @@ def __init__( def __iter__(self): # print(self.input_dataset.cameras.device) prints cpu - dataset_indices = list( - range(len(self.input_dataset)) - ) + dataset_indices = list(range(len(self.input_dataset))) worker_info = torch.utils.data.get_worker_info() if worker_info is not None: # if we have multiple processes per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) slice_start = worker_info.id * per_worker - else: # we only have a single process + else: # we only have a single process per_worker = len(self.input_dataset) slice_start = 0 worker_indices = dataset_indices[ @@ -668,12 +682,14 @@ def __iter__(self): ] # the indices of the datapoints in the dataset this worker will load r = random.Random(self.config.train_cameras_sampling_seed) r.shuffle(worker_indices) - i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera + i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera while True: - if i >= len(worker_indices): # if we've iterated through all the worker's partition of images, we need to reshuffle + if i >= len( + worker_indices + ): # if we've iterated through all the worker's partition of images, we need to reshuffle r.shuffle(worker_indices) i = 0 - idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve + idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve data, camera = undistort_view(idx, self.input_dataset, self.config.cache_images_type) if camera.metadata is None: camera.metadata = {} @@ -681,6 +697,7 @@ def __iter__(self): i += 1 yield camera, data + # class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): # def __init__( # self, @@ -701,7 +718,7 @@ def __iter__(self): # local_rank=local_rank, # **kwargs # ) - + # def setup_train(self): # self.train_imagebatch_stream = ImageBatchStream( # input_dataset=self.train_dataset, @@ -713,7 +730,7 @@ def __iter__(self): # batch_size=1, # num_workers=self.config.dataloader_num_workers, # collate_fn=identity_collate, -# # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? +# # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? # ) # self.iter_train_image_dataloader = iter(self.train_image_dataloader) @@ -740,9 +757,8 @@ def __iter__(self): # self.train_count += 1 # camera, data = next(self.iter_train_image_dataloader)[0] # return camera, data - + # def next_eval(self, step: int) -> Tuple[Cameras, Dict]: # self.eval_count += 1 # camera, data = next(self.iter_train_image_dataloader)[0] # return camera, data - \ No newline at end of file diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 13e9da7ec4..29512af124 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -15,36 +15,29 @@ """ Parallel data manager that generates training data in multiple python processes. """ + from __future__ import annotations -import concurrent.futures -import queue -import time -from dataclasses import dataclass, field from functools import cached_property from pathlib import Path from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin import torch -from pathos.helpers import mp -from rich.progress import track from torch.nn import Parameter -from nerfstudio.cameras.cameras import Cameras, CameraType +from nerfstudio.cameras.cameras import Cameras from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.datamanagers.base_datamanager import ( DataManager, TDataset, VanillaDataManagerConfig, - VanillaDataManager, variable_res_collate, ) from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import RayBatchStream, FixedIndicesEvalDataloader, RandIndicesEvalDataloader +from nerfstudio.data.pixel_samplers import PixelSampler from nerfstudio.data.utils.data_utils import identity_collate -from nerfstudio.model_components.ray_generators import RayGenerator +from nerfstudio.data.utils.dataloaders import RandIndicesEvalDataloader, RayBatchStream from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -111,15 +104,15 @@ def __init__( def dataset_type(self) -> Type[TDataset]: """Returns the dataset type passed as the generic argument""" default: Type[TDataset] = cast(TDataset, TDataset.__default__) # type: ignore - orig_class: Type[VanillaDataManager] = get_orig_class(self, default=None) # type: ignore - if type(self) is VanillaDataManager and orig_class is None: + orig_class: Type[ParallelDataManager] = get_orig_class(self, default=None) # type: ignore + if type(self) is ParallelDataManager and orig_class is None: return default - if orig_class is not None and get_origin(orig_class) is VanillaDataManager: + if orig_class is not None and get_origin(orig_class) is ParallelDataManager: return get_args(orig_class)[0] # For inherited classes, we need to find the correct type to instantiate for base in getattr(self, "__orig_bases__", []): - if get_origin(base) is VanillaDataManager: + if get_origin(base) is ParallelDataManager: for value in get_args(base): if isinstance(value, ForwardRef): if value.__forward_evaluated__: @@ -146,7 +139,6 @@ def create_eval_dataset(self) -> TDataset: scale_factor=self.config.camera_res_scale_factor, ) - def setup_train(self): self.train_raybatchstream = RayBatchStream( input_dataset=self.train_dataset, @@ -154,8 +146,8 @@ def setup_train(self): num_images_to_sample_from=self.config.train_num_images_to_sample_from, num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - collate_fn = variable_res_collate, - load_from_disk = self.config.load_from_disk, + collate_fn=variable_res_collate, + load_from_disk=self.config.load_from_disk, ) self.train_ray_dataloader = torch.utils.data.DataLoader( self.train_raybatchstream, @@ -163,10 +155,10 @@ def setup_train(self): num_workers=self.config.dataloader_num_workers, prefetch_factor=self.config.prefetch_factor, shuffle=False, - collate_fn=identity_collate, # Our dataset handles batching / collation of rays + collate_fn=identity_collate, # Our dataset handles batching / collation of rays ) self.iter_train_raybundles = iter(self.train_ray_dataloader) - + def setup_eval(self): self.eval_raybatchstream = RayBatchStream( input_dataset=self.eval_dataset, @@ -174,8 +166,8 @@ def setup_eval(self): num_images_to_sample_from=self.config.train_num_images_to_sample_from, num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, device=self.device, - collate_fn = variable_res_collate, - load_from_disk = True, + collate_fn=variable_res_collate, + load_from_disk=True, ) self.eval_ray_dataloader = torch.utils.data.DataLoader( self.eval_raybatchstream, @@ -183,7 +175,7 @@ def setup_eval(self): num_workers=self.config.dataloader_num_workers, prefetch_factor=self.config.prefetch_factor, shuffle=False, - collate_fn=identity_collate, # Our dataset handles batching / collation of rays + collate_fn=identity_collate, # Our dataset handles batching / collation of rays ) self.iter_eval_raybundles = iter(self.eval_ray_dataloader) self.image_eval_dataloader = RandIndicesEvalDataloader( diff --git a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py new file mode 100644 index 0000000000..f21f039a8d --- /dev/null +++ b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py @@ -0,0 +1,237 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Parallel data manager that outputs cameras / images instead of raybundles. +""" + +from __future__ import annotations + +import random +from functools import cached_property +from pathlib import Path +from typing import Dict, ForwardRef, Generic, List, Literal, Tuple, Type, Union, cast, get_args, get_origin + +import fpsample +import numpy as np +import torch +from torch.nn import Parameter + +from nerfstudio.cameras.cameras import Cameras +from nerfstudio.data.datamanagers.base_datamanager import DataManager, TDataset +from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig +from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs +from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.utils.data_utils import identity_collate +from nerfstudio.data.utils.dataloaders import ImageBatchStream, undistort_view +from nerfstudio.utils.misc import get_orig_class +from nerfstudio.utils.rich_utils import CONSOLE + + +class ParallelFullImageDatamanager(DataManager, Generic[TDataset]): + def __init__( + self, + config: FullImageDatamanagerConfig, + device: Union[torch.device, str] = "cpu", + test_mode: Literal["test", "val", "inference"] = "val", + world_size: int = 1, + local_rank: int = 0, + **kwargs, + ): + self.config = config + self.device = device + self.world_size = world_size + self.local_rank = local_rank + self.sampler = None + self.test_mode = test_mode + self.test_split = "test" if test_mode in ["test", "inference"] else "val" + self.dataparser_config = self.config.dataparser + if self.config.data is not None: + self.config.dataparser.data = Path(self.config.data) + else: + self.config.data = self.config.dataparser.data + self.dataparser = self.dataparser_config.setup() + if test_mode == "inference": + self.dataparser.downscale_factor = 1 # Avoid opening images + self.includes_time = self.dataparser.includes_time + + self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") + self.train_dataset = self.create_train_dataset() + self.eval_dataset = self.create_eval_dataset() + + if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": + CONSOLE.print( + "Train dataset has over 500 images, overriding cache_images to cpu", + style="bold yellow", + ) + self.config.cache_images = "cpu" + self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device + if self.config.masks_on_gpu is True: + self.exclude_batch_keys_from_device.remove("mask") + if self.config.images_on_gpu is True: + self.exclude_batch_keys_from_device.remove("image") + + # Some logic to make sure we sample every camera in equal amounts + self.train_unseen_cameras = self.sample_train_cameras() + self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] + assert len(self.train_unseen_cameras) > 0, "No data found in dataset" + + super().__init__() + + def sample_train_cameras(self): + """Return a list of camera indices sampled using the strategy specified by + self.config.train_cameras_sampling_strategy""" + num_train_cameras = len(self.train_dataset) + if self.config.train_cameras_sampling_strategy == "random": + if not hasattr(self, "random_generator"): + self.random_generator = random.Random(self.config.train_cameras_sampling_seed) + indices = list(range(num_train_cameras)) + self.random_generator.shuffle(indices) + return indices + elif self.config.train_cameras_sampling_strategy == "fps": + if not hasattr(self, "train_unsampled_epoch_count"): + np.random.seed(self.config.train_cameras_sampling_seed) # fix random seed of fpsample + self.train_unsampled_epoch_count = np.zeros(num_train_cameras) + camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy() + # We concatenate camera origins with weighted train_unsampled_epoch_count because we want to + # increase the chance to sample camera that hasn't been sampled in consecutive epochs previously. + # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of scene + data = np.concatenate( + (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1 + ) + n = self.config.fps_reset_every + if num_train_cameras < n: + CONSOLE.log( + f"num_train_cameras={num_train_cameras} is smaller than fps_reset_ever={n}, the behavior of " + "camera sampler will be very similar to sampling random without replacement (default setting)." + ) + n = num_train_cameras + kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3) + + self.train_unsampled_epoch_count += 1 + self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0 + return kdline_fps_samples_idx.tolist() + else: + raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}") + + def create_train_dataset(self) -> TDataset: + """Sets up the data loaders for training""" + return self.dataset_type( + dataparser_outputs=self.train_dataparser_outputs, + scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, + ) + + def create_eval_dataset(self) -> TDataset: + """Sets up the data loaders for evaluation""" + return self.dataset_type( + dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), + scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, + ) + + @cached_property + def dataset_type(self) -> Type[TDataset]: + """Returns the dataset type passed as the generic argument""" + default: Type[TDataset] = cast(TDataset, TDataset.__default__) # type: ignore + orig_class: Type[ParallelFullImageDatamanager] = get_orig_class(self, default=None) # type: ignore + if type(self) is ParallelFullImageDatamanager and orig_class is None: + return default + if orig_class is not None and get_origin(orig_class) is ParallelFullImageDatamanager: + return get_args(orig_class)[0] + + # For inherited classes, we need to find the correct type to instantiate + for base in getattr(self, "__orig_bases__", []): + if get_origin(base) is ParallelFullImageDatamanager: + for value in get_args(base): + if isinstance(value, ForwardRef): + if value.__forward_evaluated__: + value = value.__forward_value__ + elif value.__forward_module__ is None: + value.__forward_module__ = type(self).__module__ + value = getattr(value, "_evaluate")(None, None, set()) + assert isinstance(value, type) + if issubclass(value, InputDataset): + return cast(Type[TDataset], value) + return default + + def get_datapath(self) -> Path: + return self.config.dataparser.data + + def setup_train(self): + self.train_imagebatch_stream = ImageBatchStream( + input_dataset=self.train_dataset, + cache_images_type=self.config.cache_images_type, + sampling_seed=self.config.train_cameras_sampling_seed, + device=self.device, + ) + self.train_image_dataloader = torch.utils.data.DataLoader( + self.train_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? + ) + self.iter_train_image_dataloader = iter(self.train_image_dataloader) + + def setup_eval(self): + self.eval_imagebatch_stream = ImageBatchStream( + input_dataset=self.eval_dataset, + cache_images_type=self.config.cache_images_type, + sampling_seed=self.config.train_cameras_sampling_seed, + device=self.device, + ) + self.eval_image_dataloader = torch.utils.data.DataLoader( + self.eval_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + # pin_memory_device=self.device, + ) + self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) + + @property + def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: + return self.iter_eval_image_dataloader + + def get_param_groups(self) -> Dict[str, List[Parameter]]: + """Get the param groups for the data manager. + Returns: + A list of dictionaries containing the data manager's param groups. + """ + return {} + + def get_train_rays_per_batch(self): + # TODO: fix this to be the resolution of the last image rendered + return 800 * 800 + + def next_train(self, step: int) -> Tuple[Cameras, Dict]: + self.train_count += 1 + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + + def next_eval(self, step: int) -> Tuple[Cameras, Dict]: + self.eval_count += 1 + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + + def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: + """Returns the next evaluation batch + + Returns a Camera instead of raybundle""" + image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1)) + # Make sure to re-populate the unseen cameras list if we have exhausted it + if len(self.eval_unseen_cameras) == 0: + self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] + return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type) diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 2c7a1b1925..021b04d35d 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -46,7 +46,7 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_images: bool = False): + def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_compressed_images: bool = False): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -55,9 +55,9 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = self.cameras = deepcopy(dataparser_outputs.cameras) self.cameras.rescale_output_resolution(scaling_factor=scale_factor) self.mask_color = dataparser_outputs.metadata.get("mask_color", None) - self.cache_images = cache_images - """If cache_images == True, cache all the image files into RAM in their compressed form (not as tensors yet)""" - if cache_images: + self.cache_compressed_images = cache_compressed_images + """If cache_compressed_images == True, cache all the image files into RAM in their compressed form (jpeg, png, etc. but not as pytorch tensors)""" + if cache_compressed_images: self.binary_images = [] self.binary_masks = [] for image_filename in self._dataparser_outputs.image_filenames: @@ -78,7 +78,7 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: image_idx: The image index in the dataset. """ image_filename = self._dataparser_outputs.image_filenames[image_idx] - if self.cache_images: + if self.cache_compressed_images: pil_image = Image.open(self.binary_images[image_idx]) else: pil_image = Image.open(image_filename) @@ -146,7 +146,7 @@ def get_data(self, image_idx: int, image_type: Literal["uint8", "float32"] = "fl data = {"image_idx": image_idx, "image": image} if self._dataparser_outputs.mask_filenames is not None: - if self.cache_images: + if self.cache_compressed_images: mask_filepath = self.binary_masks[image_idx] else: mask_filepath = self._dataparser_outputs.mask_filenames[image_idx] diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index c35830e5cf..3dc9f98179 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -21,22 +21,25 @@ import multiprocessing import random from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Sized, Tuple, Union, cast -from dataclasses import field +from collections import defaultdict +from typing import Any, Callable, Dict, List, Literal, Optional, Sized, Tuple, Union, cast +import cv2 +import numpy as np import torch from rich.progress import track from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader -from nerfstudio.cameras.cameras import Cameras +from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper +from nerfstudio.cameras.cameras import Cameras, CameraType from nerfstudio.cameras.rays import RayBundle from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate +from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import get_dict_to_torch from nerfstudio.utils.rich_utils import CONSOLE -from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.model_components.ray_generators import RayGenerator def variable_res_collate(batch: List[Dict]) -> Dict: @@ -68,124 +71,337 @@ def variable_res_collate(batch: List[Dict]) -> Dict: return new_batch -# class CacheDataloader(DataLoader): -# """Collated image dataset that implements caching of default-pytorch-collatable data. -# Creates batches of the InputDataset return type. - -# Args: -# dataset: Dataset to sample from. -# num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. -# num_times_to_repeat_images: How often to yield an image batch before resampling. -1 to never pick new images. -# device: Device to perform computation. -# collate_fn: The function we will use to collate our training data -# """ - -# def __init__( -# self, -# dataset: Dataset, -# num_images_to_sample_from: int = -1, -# num_times_to_repeat_images: int = -1, -# device: Union[torch.device, str] = "cpu", -# collate_fn: Callable[[Any], Any] = nerfstudio_collate, -# exclude_batch_keys_from_device: Optional[List[str]] = None, -# **kwargs, -# ): -# if exclude_batch_keys_from_device is None: -# exclude_batch_keys_from_device = ["image"] -# self.dataset = dataset -# assert isinstance(self.dataset, Sized) - -# super().__init__(dataset=dataset, **kwargs) # This will set self.dataset -# self.num_times_to_repeat_images = num_times_to_repeat_images -# self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) -# self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from -# self.device = device -# self.collate_fn = collate_fn -# self.num_workers = kwargs.get("num_workers", 0) -# self.exclude_batch_keys_from_device = exclude_batch_keys_from_device - -# self.num_repeated = self.num_times_to_repeat_images # starting value -# self.first_time = True - -# self.cached_collated_batch = None -# if self.cache_all_images: -# CONSOLE.print(f"Caching all {len(self.dataset)} images.") -# if len(self.dataset) > 500: -# CONSOLE.print( -# "[bold yellow]Warning: If you run out of memory, try reducing the number of images to sample from." -# ) -# self.cached_collated_batch = self._get_collated_batch() -# elif self.num_times_to_repeat_images == -1: -# CONSOLE.print( -# f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, without resampling." -# ) -# else: -# CONSOLE.print( -# f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, " -# f"resampling every {self.num_times_to_repeat_images} iters." -# ) - -# def __getitem__(self, idx): -# return self.dataset.__getitem__(idx) - -# def _get_batch_list(self): -# """Returns a list of batches from the dataset attribute.""" - -# assert isinstance(self.dataset, Sized) -# indices = random.sample(range(len(self.dataset)), k=self.num_images_to_sample_from) -# batch_list = [] -# results = [] - -# num_threads = int(self.num_workers) * 4 -# num_threads = min(num_threads, multiprocessing.cpu_count() - 1) -# num_threads = max(num_threads, 1) - -# with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: -# for idx in indices: -# res = executor.submit(self.dataset.__getitem__, idx) -# results.append(res) - -# for res in track(results, description="Loading data batch", transient=True): -# batch_list.append(res.result()) - -# return batch_list - -# def _get_collated_batch(self): -# """Returns a collated batch of images.""" -# batch_list = self._get_batch_list() -# collated_batch = self.collate_fn(batch_list) -# collated_batch = get_dict_to_torch( -# collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device -# ) -# return collated_batch - -# def __iter__(self): -# while True: -# if self.cache_all_images: -# collated_batch = self.cached_collated_batch -# elif self.first_time or ( -# self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images -# ): -# # trigger a reset -# self.num_repeated = 0 -# collated_batch = self._get_collated_batch() -# # possibly save a cached item -# self.cached_collated_batch = collated_batch if self.num_times_to_repeat_images != 0 else None -# self.first_time = False -# else: -# collated_batch = self.cached_collated_batch -# self.num_repeated += 1 -# yield collated_batch +def _undistort_image( + camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]: + mask = None + if camera.camera_type.item() == CameraType.PERSPECTIVE.value: + assert distortion_params[3] == 0, ( + "We don't support the 4th Brown parameter for image undistortion, " + "Only k1, k2, k3, p1, p2 can be non-zero." + ) + # we rearrange the distortion parameters because OpenCV expects the order (k1, k2, p1, p2, k3) + # see https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html + distortion_params = np.array( + [ + distortion_params[0], + distortion_params[1], + distortion_params[4], + distortion_params[5], + distortion_params[2], + distortion_params[3], + 0, + 0, + ] + ) + # because OpenCV expects the pixel coord to be top-left, we need to shift the principal point by 0.5 + # see https://github.com/nerfstudio-project/nerfstudio/issues/3048 + K[0, 2] = K[0, 2] - 0.5 + K[1, 2] = K[1, 2] - 0.5 + if np.any(distortion_params): + newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) + image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore + else: + newK = K + roi = 0, 0, image.shape[1], image.shape[0] + # crop the image and update the intrinsics accordingly + x, y, w, h = roi + image = image[y : y + h, x : x + w] + newK[0, 2] -= x + newK[1, 2] -= y + + if "depth_image" in data: + data["depth_image"] = data["depth_image"][y : y + h, x : x + w] + if "mask" in data: + mask = data["mask"].numpy() + mask = mask.astype(np.uint8) * 255 + if np.any(distortion_params): + mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore + mask = mask[y : y + h, x : x + w] + mask = torch.from_numpy(mask).bool() + if len(mask.shape) == 2: + mask = mask[:, :, None] + newK[0, 2] = newK[0, 2] + 0.5 + newK[1, 2] = newK[1, 2] + 0.5 + K = newK + + elif camera.camera_type.item() == CameraType.FISHEYE.value: + K[0, 2] = K[0, 2] - 0.5 + K[1, 2] = K[1, 2] - 0.5 + distortion_params = np.array( + [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]] + ) + newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( + K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0 + ) + map1, map2 = cv2.fisheye.initUndistortRectifyMap( + K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1 + ) + # and then remap: + image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) + if "mask" in data: + mask = data["mask"].numpy() + mask = mask.astype(np.uint8) * 255 + mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK) + mask = torch.from_numpy(mask).bool() + if len(mask.shape) == 2: + mask = mask[:, :, None] + newK[0, 2] = newK[0, 2] + 0.5 + newK[1, 2] = newK[1, 2] + 0.5 + K = newK + elif camera.camera_type.item() == CameraType.FISHEYE624.value: + fisheye624_params = torch.cat( + [camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0 + ) + assert fisheye624_params.shape == (16,) + assert ( + "mask" not in data + and camera.metadata is not None + and "fisheye_crop_radius" in camera.metadata + and isinstance(camera.metadata["fisheye_crop_radius"], float) + ) + fisheye_crop_radius = camera.metadata["fisheye_crop_radius"] + + # Approximate the FOV of the unmasked region of the camera. + upper, lower, left, right = fisheye624_unproject_helper( + torch.tensor( + [ + [camera.cx, camera.cy - fisheye_crop_radius], + [camera.cx, camera.cy + fisheye_crop_radius], + [camera.cx - fisheye_crop_radius, camera.cy], + [camera.cx + fisheye_crop_radius, camera.cy], + ], + dtype=torch.float32, + )[None], + params=fisheye624_params[None], + ).squeeze(dim=0) + fov_radians = torch.max( + torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))), + torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))), + ) + + # Heuristics to determine parameters of an undistorted image. + undist_h = int(fisheye_crop_radius * 2) + undist_w = int(fisheye_crop_radius * 2) + undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0)) + undist_K = torch.eye(3) + undist_K[0, 0] = undistort_focal # fx + undist_K[1, 1] = undistort_focal # fy + undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0). + undist_K[1, 2] = (undist_h - 1) / 2.0 # cy + + # Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates. + undist_uv_homog = torch.stack( + [ + *torch.meshgrid( + torch.arange(undist_w, dtype=torch.float32), + torch.arange(undist_h, dtype=torch.float32), + ), + torch.ones((undist_w, undist_h), dtype=torch.float32), + ], + dim=-1, + ) + assert undist_uv_homog.shape == (undist_w, undist_h, 3) + dist_uv = ( + fisheye624_project( + xyz=( + torch.einsum( + "ij,bj->bi", + torch.linalg.inv(undist_K), + undist_uv_homog.reshape((undist_w * undist_h, 3)), + )[None] + ), + params=fisheye624_params[None, :], + ) + .reshape((undist_w, undist_h, 2)) + .numpy() + ) + map1 = dist_uv[..., 1] + map2 = dist_uv[..., 0] + + # Use correspondence to undistort image. + image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) + + # Compute undistorted mask as well. + dist_h = camera.height.item() + dist_w = camera.width.item() + mask = np.mgrid[:dist_h, :dist_w] + mask[0, ...] -= dist_h // 2 + mask[1, ...] -= dist_w // 2 + mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius + mask = torch.from_numpy( + cv2.remap( + mask.astype(np.uint8) * 255, + map1, + map2, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0, + ) + / 255.0 + ).bool()[..., None] + if len(mask.shape) == 2: + mask = mask[:, :, None] + assert mask.shape == (undist_h, undist_w, 1) + K = undist_K.numpy() + else: + raise NotImplementedError("Only perspective and fisheye cameras are supported") + return K, image, mask + + +def undistort_view( + idx: int, dataset: InputDataset, image_type: Literal["uint8", "float32"] = "float32" +) -> Dict[str, torch.Tensor]: + """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics + Note: this method does not modify the dataset's attributes at all. + + Returns: The undistorted data (image, depth, mask, etc.) and the new linear Camera object + """ + data = dataset.get_data(idx, image_type) + camera = dataset.cameras[idx].reshape(()) + assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( + f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' + f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' + ) + if camera.distortion_params is None or torch.all(camera.distortion_params == 0): + return data + K = camera.get_intrinsics_matrices().numpy() + distortion_params = camera.distortion_params.numpy() + image = data["image"].numpy() + K, image, mask = _undistort_image(camera, distortion_params, data, image, K) + data["image"] = torch.from_numpy(image) + if mask is not None: + data["mask"] = mask + + # create a new Camera with the rectified / undistorted intrinsics + new_camera = Cameras( + camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), + fx=torch.Tensor([[float(K[0, 0])]]), + fy=torch.Tensor([[float(K[1, 1])]]), + cx=torch.Tensor([[float(K[0, 2])]]), + cy=torch.Tensor([[float(K[1, 2])]]), + width=torch.Tensor([[image.shape[1]]]).to(torch.int32), + height=torch.Tensor([[image.shape[0]]]).to(torch.int32), + ) + return new_camera, data + + +class CacheDataloader(DataLoader): + """Collated image dataset that implements caching of default-pytorch-collatable data. + Creates batches of the InputDataset return type. + + Args: + dataset: Dataset to sample from. + num_samples_to_collate: How many images to sample rays for each batch. -1 for all images. + num_times_to_repeat_images: How often to yield an image batch before resampling. -1 to never pick new images. + device: Device to perform computation. + collate_fn: The function we will use to collate our training data + """ + + def __init__( + self, + dataset: Dataset, + num_images_to_sample_from: int = -1, + num_times_to_repeat_images: int = -1, + device: Union[torch.device, str] = "cpu", + collate_fn: Callable[[Any], Any] = nerfstudio_collate, + exclude_batch_keys_from_device: Optional[List[str]] = None, + **kwargs, + ): + if exclude_batch_keys_from_device is None: + exclude_batch_keys_from_device = ["image"] + self.dataset = dataset + assert isinstance(self.dataset, Sized) + + super().__init__(dataset=dataset, **kwargs) # This will set self.dataset + self.num_times_to_repeat_images = num_times_to_repeat_images + self.cache_all_images = (num_images_to_sample_from == -1) or (num_images_to_sample_from >= len(self.dataset)) + self.num_images_to_sample_from = len(self.dataset) if self.cache_all_images else num_images_to_sample_from + self.device = device + self.collate_fn = collate_fn + self.num_workers = kwargs.get("num_workers", 0) + self.exclude_batch_keys_from_device = exclude_batch_keys_from_device + + self.num_repeated = self.num_times_to_repeat_images # starting value + self.first_time = True + + self.cached_collated_batch = None + if self.cache_all_images: + CONSOLE.print(f"Caching all {len(self.dataset)} images.") + if len(self.dataset) > 500: + CONSOLE.print( + "[bold yellow]Warning: If you run out of memory, try reducing the number of images to sample from." + ) + self.cached_collated_batch = self._get_collated_batch() + elif self.num_times_to_repeat_images == -1: + CONSOLE.print( + f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, without resampling." + ) + else: + CONSOLE.print( + f"Caching {self.num_images_to_sample_from} out of {len(self.dataset)} images, " + f"resampling every {self.num_times_to_repeat_images} iters." + ) + + def __getitem__(self, idx): + return self.dataset.__getitem__(idx) + + def _get_batch_list(self): + """Returns a list of batches from the dataset attribute.""" + + assert isinstance(self.dataset, Sized) + indices = random.sample(range(len(self.dataset)), k=self.num_images_to_sample_from) + batch_list = [] + results = [] + + num_threads = int(self.num_workers) * 4 + num_threads = min(num_threads, multiprocessing.cpu_count() - 1) + num_threads = max(num_threads, 1) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for idx in indices: + res = executor.submit(self.dataset.__getitem__, idx) + results.append(res) + + for res in track(results, description="Loading data batch", transient=True): + batch_list.append(res.result()) + + return batch_list + + def _get_collated_batch(self): + """Returns a collated batch of images.""" + batch_list = self._get_batch_list() + collated_batch = self.collate_fn(batch_list) + collated_batch = get_dict_to_torch( + collated_batch, device=self.device, exclude=self.exclude_batch_keys_from_device + ) + return collated_batch + + def __iter__(self): + while True: + if self.cache_all_images: + collated_batch = self.cached_collated_batch + elif self.first_time or ( + self.num_times_to_repeat_images != -1 and self.num_repeated >= self.num_times_to_repeat_images + ): + # trigger a reset + self.num_repeated = 0 + collated_batch = self._get_collated_batch() + # possibly save a cached item + self.cached_collated_batch = collated_batch if self.num_times_to_repeat_images != 0 else None + self.first_time = False + else: + collated_batch = self.cached_collated_batch + self.num_repeated += 1 + yield collated_batch + -import concurrent.futures import math -import multiprocessing -import random -from typing import Sized + from torch.utils.data import Dataset -from nerfstudio.utils.misc import get_dict_to_torch from tqdm.auto import tqdm + class RayBatchStream(torch.utils.data.IterableDataset): """Wrapper around Pytorch's IterableDataset to generate the next batch of rays (next RayBundle) and corresponding labels with multiple parallel workers. @@ -195,20 +411,20 @@ class RayBatchStream(torch.utils.data.IterableDataset): bottlenecked by disk read speed. To avoid Out-Of-Memory (OOM) errors, this batch of images is small and regenerated by resampling the worker's partition of images to maintain sampling diversity. """ + def __init__( self, - input_dataset: Dataset, + input_dataset: InputDataset, num_rays_per_batch: int = 1024, num_images_to_sample_from: int = -1, num_times_to_repeat_images: int = -1, device: Union[torch.device, str] = "cpu", # variable_res_collate avoids np.stack'ing images, which allows it to be much faster than `nerfstudio_collate` - collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(variable_res_collate)), + collate_fn: Callable[[Any], Any] = cast(Any, staticmethod(variable_res_collate)), num_image_load_threads: int = 4, exclude_batch_keys_from_device: Optional[List[str]] = None, load_from_disk: bool = False, patch_size: int = 1, - ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -222,7 +438,7 @@ def __init__( """How many RayBundles to generate from this batch of images after sampling `num_images_to_sample_from` images.""" self.device = device """If a CUDA GPU is present, self.device will be set to use that GPU.""" - self.collate_fn = collate_fn + self.collate_fn = collate_fn """What collate function is used to batch images to be used for pixel sampling and ray generation. """ self.num_image_load_threads = num_image_load_threads """Number of threads created to read images from disk and form collated batches.""" @@ -247,9 +463,7 @@ def _get_pixel_sampler(self, dataset: Dataset, num_rays_per_batch: int) -> Pixel from nerfstudio.cameras.cameras import CameraType if self.patch_size > 1 and type(self.pixel_sampler_config) is PixelSamplerConfig: - return PatchPixelSamplerConfig().setup( - patch_size=self.patch_size, num_rays_per_batch=num_rays_per_batch - ) + return PatchPixelSamplerConfig().setup(patch_size=self.patch_size, num_rays_per_batch=num_rays_per_batch) is_equirectangular = (dataset.cameras.camera_type == CameraType.EQUIRECTANGULAR.value).all() if is_equirectangular.any(): CONSOLE.print("[bold yellow]Warning: Some cameras are equirectangular, but using default pixel sampler.") @@ -292,10 +506,10 @@ def _get_batch_list(self, indices=None): for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) results.append(res) - results = tqdm(results) # this is temporary and will be removed in the final push + results = tqdm(results) # this is temporary and will be removed in the final push for res in results: batch_list.append(res.result()) - + return batch_list def _get_collated_batch(self, indices=None): @@ -322,31 +536,33 @@ def __iter__(self): else: # we only have a single process per_worker = len(self.input_dataset) slice_start = 0 - dataset_indices = list( - range(len(self.input_dataset)) - ) + dataset_indices = list(range(len(self.input_dataset))) worker_indices = dataset_indices[ slice_start : slice_start + per_worker ] # the indices of the datapoints in the dataset this worker will load if self.enable_per_worker_image_caching: self._cached_collated_batch = self._get_collated_batch(worker_indices) r = random.Random(3301) - num_rays_per_loop = self.num_rays_per_batch # default train_num_rays_per_batch is 4096 + num_rays_per_loop = self.num_rays_per_batch # default train_num_rays_per_batch is 4096 # each worker has its own pixel sampler worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) - self.ray_generator = RayGenerator(self.input_dataset.cameras) # the generated RayBundles will be on the same device as self.input_dataset.cameras (CPU) - + self.ray_generator = RayGenerator( + self.input_dataset.cameras + ) # the generated RayBundles will be on the same device as self.input_dataset.cameras (CPU) + i = 0 while True: if self.enable_per_worker_image_caching: collated_batch = self._cached_collated_batch elif i % self.num_times_to_repeat_images == 0: r.shuffle(worker_indices) - - if self.num_images_to_sample_from == -1: # if -1, the worker gets all available indices in its partition + + if ( + self.num_images_to_sample_from == -1 + ): # if -1, the worker gets all available indices in its partition image_indices = worker_indices - else: # get a total of 'num_images_to_sample_from' image indices - image_indices = worker_indices[:self.num_images_to_sample_from] + else: # get a total of 'num_images_to_sample_from' image indices + image_indices = worker_indices[: self.num_images_to_sample_from] collated_batch = self._get_collated_batch(image_indices) i += 1 @@ -358,13 +574,68 @@ def __iter__(self): What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) """ - batch = worker_pixel_sampler.sample(collated_batch) # the pixel_sampler will sample num_rays_per_batch pixels. + batch = worker_pixel_sampler.sample( + collated_batch + ) # the pixel_sampler will sample num_rays_per_batch pixels. # collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image' ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices).to(self.device) # the ray_bundle is on the GPU; batch["image"] is on the CPU + ray_bundle = self.ray_generator(ray_indices).to( + self.device + ) # the ray_bundle is on the GPU; batch["image"] is on the CPU yield ray_bundle, batch +class ImageBatchStream(torch.utils.data.IterableDataset): + """ + A wrapper of InputDataset that outputs undistorted full images and cameras. This makes the + datamanager more lightweight since we don't have to do generate rays. Useful for full-image + training e.g. rasterization pipelines + """ + + def __init__( + self, + input_dataset: InputDataset, + cache_images_type: Literal["uint8", "float32"] = "float32", + sampling_seed: int = 3301, + device: Union[torch.device, str] = "cpu", + ): + self.input_dataset = input_dataset + self.cache_images_type = cache_images_type + self.sampling_seed = sampling_seed + self.device = device + + def __iter__(self): + # print(self.input_dataset.cameras.device) prints cpu + dataset_indices = list(range(len(self.input_dataset))) + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: # if we have multiple processes + per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) + slice_start = worker_info.id * per_worker + else: # we only have a single process + per_worker = len(self.input_dataset) + slice_start = 0 + worker_indices = dataset_indices[ + slice_start : slice_start + per_worker + ] # the indices of the datapoints in the dataset this worker will load + r = random.Random(self.sampling_seed) + r.shuffle(worker_indices) + i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera + print("HELLO", worker_info.id) + while True: + if i >= len( + worker_indices + ): # if we've iterated through all the worker's partition of images, we need to reshuffle + r.shuffle(worker_indices) + i = 0 + idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve + camera, data = undistort_view(idx, self.input_dataset, self.cache_images_type) + if camera.metadata is None: + camera.metadata = {} + camera.metadata["cam_idx"] = idx + i += 1 + yield camera, data + + class EvalDataloader(DataLoader): """Evaluation dataloader base class diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 731f214e77..969ef82c0c 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -15,6 +15,7 @@ """ Abstracts for the Pipeline class. """ + from __future__ import annotations import typing @@ -298,6 +299,9 @@ def get_train_loss_dict(self, step: int): step: current iteration step to update sampler if using DDP (distributed) """ ray_bundle, batch = self.datamanager.next_train(step) + ray_bundle = ray_bundle.to( + self.device + ) # for some reason this line of code is needed otherwise viewmats will not be invertible? model_outputs = self._model(ray_bundle) # train distributed data parallel model if world_size > 1 metrics_dict = self.model.get_metrics_dict(model_outputs, batch) loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) From 4ebad857b0dd7eecf7e26a4474b2b0716d1935b8 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 23 Sep 2024 02:13:57 -0700 Subject: [PATCH 59/78] added custom ray processing API to support implementations like LERF, cleaned up FullImageDatamanager to original because of new ParallelImageDatamanger --- .../datamanagers/full_images_datamanager.py | 152 +----------------- .../data/datamanagers/parallel_datamanager.py | 4 + nerfstudio/data/utils/dataloaders.py | 48 +++--- 3 files changed, 28 insertions(+), 176 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 46d83e8311..f3a08d4188 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -45,6 +45,7 @@ from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.utils.data_utils import identity_collate +from nerfstudio.data.utils.dataloaders import ImageBatchStream from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -611,154 +612,3 @@ def _undistort_image( else: raise NotImplementedError("Only perspective and fisheye cameras are supported") return K, image, mask - - -# def undistort_view(idx: int, dataset: TDataset, image_type: Literal["uint8", "float32"] = "float32") -> Dict[str, torch.Tensor]: -# """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics -# Note: this method does not modify the dataset's attributes at all. - -# Returns: The undistorted data (image, depth, mask, etc.) and the new linear Camera object -# """ -# data = dataset.get_data(idx, image_type) -# camera = dataset.cameras[idx].reshape(()) -# assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), ( -# f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' -# f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' -# ) -# if camera.distortion_params is None or torch.all(camera.distortion_params == 0): -# return data -# K = camera.get_intrinsics_matrices().numpy() -# distortion_params = camera.distortion_params.numpy() -# image = data["image"].numpy() -# K, image, mask = _undistort_image(camera, distortion_params, data, image, K) -# data["image"] = torch.from_numpy(image) -# if mask is not None: -# data["mask"] = mask - -# # create a new Camera with the rectified / undistorted intrinsics -# new_camera = Cameras( -# camera_to_worlds=camera.camera_to_worlds.unsqueeze(0), -# fx=torch.Tensor([[float(K[0, 0])]]), -# fy=torch.Tensor([[float(K[1, 1])]]), -# cx=torch.Tensor([[float(K[0, 2])]]), -# cy=torch.Tensor([[float(K[1, 2])]]), -# width=torch.Tensor([[image.shape[1]]]).to(torch.int32), -# height=torch.Tensor([[image.shape[0]]]).to(torch.int32), -# ) -# return data, new_camera - -import math - - -class ImageBatchStream(torch.utils.data.IterableDataset): - """ - A wrapper of InputDataset that outputs undistorted full images and cameras. This makes the - datamanager more lightweight since we don't have to do generate rays. Useful for full-image - training e.g. rasterization pipelines - """ - - def __init__( - self, - datamanager_config: DataManagerConfig, - input_dataset: TDataset, - device, - ): - self.config = datamanager_config - self.input_dataset = input_dataset - self.device = device - - def __iter__(self): - # print(self.input_dataset.cameras.device) prints cpu - dataset_indices = list(range(len(self.input_dataset))) - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: # if we have multiple processes - per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) - slice_start = worker_info.id * per_worker - else: # we only have a single process - per_worker = len(self.input_dataset) - slice_start = 0 - worker_indices = dataset_indices[ - slice_start : slice_start + per_worker - ] # the indices of the datapoints in the dataset this worker will load - r = random.Random(self.config.train_cameras_sampling_seed) - r.shuffle(worker_indices) - i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera - while True: - if i >= len( - worker_indices - ): # if we've iterated through all the worker's partition of images, we need to reshuffle - r.shuffle(worker_indices) - i = 0 - idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve - data, camera = undistort_view(idx, self.input_dataset, self.config.cache_images_type) - if camera.metadata is None: - camera.metadata = {} - camera.metadata["cam_idx"] = idx - i += 1 - yield camera, data - - -# class ParallelFullImageDatamanager(FullImageDatamanager, Generic[TDataset]): -# def __init__( -# self, -# config: FullImageDatamanagerConfig, -# device: Union[torch.device, str] = "cpu", -# test_mode: Literal["test", "val", "inference"] = "val", -# world_size: int = 1, -# local_rank: int = 0, -# **kwargs -# ): -# import torch.multiprocessing as mp -# mp.set_start_method("spawn") -# super().__init__( -# config=config, -# device=device, -# test_mode=test_mode, -# world_size=world_size, -# local_rank=local_rank, -# **kwargs -# ) - -# def setup_train(self): -# self.train_imagebatch_stream = ImageBatchStream( -# input_dataset=self.train_dataset, -# datamanager_config=self.config, -# device=self.device, -# ) -# self.train_image_dataloader = torch.utils.data.DataLoader( -# self.train_imagebatch_stream, -# batch_size=1, -# num_workers=self.config.dataloader_num_workers, -# collate_fn=identity_collate, -# # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? -# ) -# self.iter_train_image_dataloader = iter(self.train_image_dataloader) - -# def setup_eval(self): -# self.eval_imagebatch_stream = ImageBatchStream( -# input_dataset=self.eval_dataset, -# datamanager_config=self.config, -# device=self.device, -# ) -# self.eval_image_dataloader = torch.utils.data.DataLoader( -# self.eval_imagebatch_stream, -# batch_size=1, -# num_workers=self.config.dataloader_num_workers, -# collate_fn=identity_collate, -# # pin_memory_device=self.device, -# ) -# self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) - -# @property -# def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: -# return self.iter_eval_image_dataloader - -# def next_train(self, step: int) -> Tuple[Cameras, Dict]: -# self.train_count += 1 -# camera, data = next(self.iter_train_image_dataloader)[0] -# return camera, data - -# def next_eval(self, step: int) -> Tuple[Cameras, Dict]: -# self.eval_count += 1 -# camera, data = next(self.iter_train_image_dataloader)[0] -# return camera, data diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 29512af124..68b7a75d32 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -221,3 +221,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: A list of dictionaries containing the data manager's param groups. """ return {} + + def custom_ray_processor(self, ray_bundle: RayBundle, batch: Dict) -> Tuple[RayBundle, Dict]: + """An API to add latents, metadata, or other further customization to the RayBundle dataloading process that is parallelized""" + return ray_bundle, batch diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 3dc9f98179..b1c94ef069 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -425,6 +425,7 @@ def __init__( exclude_batch_keys_from_device: Optional[List[str]] = None, load_from_disk: bool = False, patch_size: int = 1, + custom_ray_processor: Optional[Callable[[RayBundle, Dict], Tuple[RayBundle, Dict]]] = None, ): if exclude_batch_keys_from_device is None: exclude_batch_keys_from_device = ["image"] @@ -447,16 +448,17 @@ def __init__( For instance, if you would like to conserve GPU memory, don't move the image tensors to the GPU, which comes at a cost of total training time. The default value is ['image'].""" self.load_from_disk = load_from_disk + """If True, conserves RAM memory by loading images from disk. + If False, each worker caches all the images in its dataset partition as tensors to RAM and loads from RAM.""" self.patch_size = patch_size """Size of patch to sample from. If > 1, patch-based sampling will be used.""" - self.enable_per_worker_image_caching = load_from_disk == False - """If True, each worker's will cache its entire partition of the image dataset as image tensors in RAM.""" self._cached_collated_batch = None """Each worker has a self._cached_collated_batch contains a collated batch of images cached in RAM for a specific worker that's ready for pixel sampling.""" self.pixel_sampler_config: PixelSamplerConfig = PixelSamplerConfig() """Specifies the pixel sampler config used to sample pixels from images. Each worker will have its own pixel sampler""" self.ray_generator: RayGenerator = None """Each worker will have its own ray generator, so this is set to None for now.""" + self.custom_ray_processor = custom_ray_processor def _get_pixel_sampler(self, dataset: Dataset, num_rays_per_batch: int) -> PixelSampler: """copied from VanillaDataManager.""" @@ -492,11 +494,7 @@ def _get_batch_list(self, indices=None): batch_list = [] results = [] - num_threads = ( - int(self.num_image_load_threads) - if not self.enable_per_worker_image_caching - else 4 * int(self.num_image_load_threads) - ) + num_threads = int(self.num_image_load_threads) if self.load_from_disk else 4 * int(self.num_image_load_threads) num_threads = min(num_threads, multiprocessing.cpu_count() - 1) num_threads = max(num_threads, 1) @@ -537,31 +535,31 @@ def __iter__(self): per_worker = len(self.input_dataset) slice_start = 0 dataset_indices = list(range(len(self.input_dataset))) - worker_indices = dataset_indices[ - slice_start : slice_start + per_worker - ] # the indices of the datapoints in the dataset this worker will load - if self.enable_per_worker_image_caching: + # the indices of the datapoints in the dataset this worker will load + worker_indices = dataset_indices[slice_start : slice_start + per_worker] + if not self.load_from_disk: self._cached_collated_batch = self._get_collated_batch(worker_indices) r = random.Random(3301) num_rays_per_loop = self.num_rays_per_batch # default train_num_rays_per_batch is 4096 + # each worker has its own pixel sampler worker_pixel_sampler = self._get_pixel_sampler(self.input_dataset, num_rays_per_loop) - self.ray_generator = RayGenerator( - self.input_dataset.cameras - ) # the generated RayBundles will be on the same device as self.input_dataset.cameras (CPU) + + # the generated RayBundles will be on the same device as self.input_dataset.cameras (CPU) + self.ray_generator = RayGenerator(self.input_dataset.cameras) i = 0 while True: - if self.enable_per_worker_image_caching: + if not self.load_from_disk: collated_batch = self._cached_collated_batch elif i % self.num_times_to_repeat_images == 0: r.shuffle(worker_indices) - if ( - self.num_images_to_sample_from == -1 - ): # if -1, the worker gets all available indices in its partition + if self.num_images_to_sample_from == -1: + # if -1, the worker gets all available indices in its partition image_indices = worker_indices - else: # get a total of 'num_images_to_sample_from' image indices + else: + # get a total of 'num_images_to_sample_from' image indices image_indices = worker_indices[: self.num_images_to_sample_from] collated_batch = self._get_collated_batch(image_indices) @@ -574,14 +572,14 @@ def __iter__(self): What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask, and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol) """ - batch = worker_pixel_sampler.sample( - collated_batch - ) # the pixel_sampler will sample num_rays_per_batch pixels. + batch = worker_pixel_sampler.sample(collated_batch) # type: ignore # collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image' ray_indices = batch["indices"] - ray_bundle = self.ray_generator(ray_indices).to( - self.device - ) # the ray_bundle is on the GPU; batch["image"] is on the CPU + # the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU + ray_bundle = self.ray_generator(ray_indices).to(self.device) + if self.custom_ray_processor: + ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch) + yield ray_bundle, batch From 87921beea02f21d46b6d602abb30023679d158f0 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Tue, 24 Sep 2024 01:51:36 -0700 Subject: [PATCH 60/78] adding functionality for ns-eval by adding FixedIndicesEvalDataloader to the setup_eval --- nerfstudio/data/datamanagers/parallel_datamanager.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 68b7a75d32..5ce431643e 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -37,13 +37,13 @@ from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PixelSampler from nerfstudio.data.utils.data_utils import identity_collate -from nerfstudio.data.utils.dataloaders import RandIndicesEvalDataloader, RayBatchStream +from nerfstudio.data.utils.dataloaders import FixedIndicesEvalDataloader, RandIndicesEvalDataloader, RayBatchStream from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE class ParallelDataManager(DataManager, Generic[TDataset]): - """Data manager implementation for parallel dataloading + """Data manager implementation for parallel dataloading. Args: config: the DataManagerConfig used to instantiate class @@ -183,6 +183,12 @@ def setup_eval(self): device=self.device, num_workers=self.world_size * 4, ) + # this is used for ns-eval + self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader( + input_dataset=self.eval_dataset, + device=self.device, + num_workers=self.world_size * 4, + ) def next_train(self, step: int) -> Tuple[RayBundle, Dict]: """Returns the next batch of data from the train dataloader.""" From b628c7c744aebae4d6aee35c89d9e83f2b868d91 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 27 Sep 2024 02:42:04 -0700 Subject: [PATCH 61/78] adding both ray API and image-view API to datamanagers for custom parallelization --- nerfstudio/data/datamanagers/parallel_datamanager.py | 2 ++ .../data/datamanagers/parallel_full_images_datamanager.py | 6 ++++++ nerfstudio/data/utils/dataloaders.py | 8 +++++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 5ce431643e..925176942e 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -148,6 +148,7 @@ def setup_train(self): device=self.device, collate_fn=variable_res_collate, load_from_disk=self.config.load_from_disk, + custom_ray_processor=self.custom_ray_processor, ) self.train_ray_dataloader = torch.utils.data.DataLoader( self.train_raybatchstream, @@ -168,6 +169,7 @@ def setup_eval(self): device=self.device, collate_fn=variable_res_collate, load_from_disk=True, + custom_ray_processor=self.custom_ray_processor, ) self.eval_ray_dataloader = torch.utils.data.DataLoader( self.eval_raybatchstream, diff --git a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py index f21f039a8d..4babc77242 100644 --- a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py @@ -175,6 +175,7 @@ def setup_train(self): cache_images_type=self.config.cache_images_type, sampling_seed=self.config.train_cameras_sampling_seed, device=self.device, + custom_view_processor=self.custom_view_processor, ) self.train_image_dataloader = torch.utils.data.DataLoader( self.train_imagebatch_stream, @@ -191,6 +192,7 @@ def setup_eval(self): cache_images_type=self.config.cache_images_type, sampling_seed=self.config.train_cameras_sampling_seed, device=self.device, + custom_view_processor=self.custom_view_processor, ) self.eval_image_dataloader = torch.utils.data.DataLoader( self.eval_imagebatch_stream, @@ -235,3 +237,7 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: if len(self.eval_unseen_cameras) == 0: self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type) + + def custom_view_processor(self, camera, image): + """An API to add latents, metadata, or other further customization an camera-and-image view dataloading process that is parallelized""" + return camera, image diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index b1c94ef069..1d6de532e1 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -251,7 +251,7 @@ def _undistort_image( def undistort_view( idx: int, dataset: InputDataset, image_type: Literal["uint8", "float32"] = "float32" -) -> Dict[str, torch.Tensor]: +) -> Tuple[Cameras, Dict]: """Undistorts an image to one taken by a linear (pinhole) camera model and returns a new Camera with these updated intrinsics Note: this method does not modify the dataset's attributes at all. @@ -596,6 +596,7 @@ def __init__( cache_images_type: Literal["uint8", "float32"] = "float32", sampling_seed: int = 3301, device: Union[torch.device, str] = "cpu", + custom_view_processor: Optional[Callable[[Cameras, Dict], Tuple[Cameras, Dict]]] = None, ): self.input_dataset = input_dataset self.cache_images_type = cache_images_type @@ -630,6 +631,11 @@ def __iter__(self): if camera.metadata is None: camera.metadata = {} camera.metadata["cam_idx"] = idx + + # Apply custom processing if provided + if self.custom_view_processor: + camera, data = self.custom_view_processor(camera, data) + i += 1 yield camera, data From d2785d1806d01fb405486a2a5c3041fd46b05729 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 29 Sep 2024 19:06:57 -0700 Subject: [PATCH 62/78] updating splatfacto config for 4k tests --- nerfstudio/configs/method_configs.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 426132f1e5..83ce51cb2d 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -304,8 +304,6 @@ viewer=ViewerConfig(num_rays_per_chunk=1 << 12), vis="viewer", ) -# -# method_configs["mipnerf"] = TrainerConfig( method_name="mipnerf", pipeline=VanillaPipelineConfig( @@ -607,8 +605,8 @@ pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( _target=ParallelFullImageDatamanager[InputDataset], - # dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), - dataparser=NerfstudioDataParserConfig(load_3D_points=True), + dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), + # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", use_parallel_dataloader=True, cache_images="disk", From 436af9d2f9bbb23a3213d4c58aa9817f9b06226a Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 29 Sep 2024 19:07:39 -0700 Subject: [PATCH 63/78] updating docstrings to be more descriptive --- nerfstudio/data/datamanagers/full_images_datamanager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index f3a08d4188..e2fedbd09b 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -59,7 +59,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): along with relevant information about camera intrinsics """ cache_images: Literal["cpu", "gpu", "disk"] = "gpu" - """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device. If "disk", keeps images on disk. """ + """Whether to cache images as pytorch tensors in memory. If "cpu", caches on cpu. If "gpu", caches on device. If "disk", keeps images on disk. """ cache_images_type: Literal["uint8", "float32"] = "float32" """The image type returned from manager, caching images in uint8 saves memory""" max_thread_workers: Optional[int] = None @@ -74,7 +74,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): fps_reset_every: int = 100 """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every samples from the pool of all training cameras without replacement before a new round of sampling starts.""" - use_parallel_dataloader: bool = cache_images == "disk" + use_parallel_dataloader: bool = False """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" load_from_disk: bool = False """If True, conserves RAM memory by loading images from disk. From dd4daaa70c5d52abc753ff9d98b7acc8d9f02669 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 29 Sep 2024 19:10:04 -0700 Subject: [PATCH 64/78] new datamanager API breaks when setup_eval() has multiple workers, not sure why but single worker will have to do --- nerfstudio/data/datamanagers/parallel_datamanager.py | 3 +-- .../datamanagers/parallel_full_images_datamanager.py | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 925176942e..d97e319639 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -174,8 +174,7 @@ def setup_eval(self): self.eval_ray_dataloader = torch.utils.data.DataLoader( self.eval_raybatchstream, batch_size=1, - num_workers=self.config.dataloader_num_workers, - prefetch_factor=self.config.prefetch_factor, + num_workers=0, shuffle=False, collate_fn=identity_collate, # Our dataset handles batching / collation of rays ) diff --git a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py index 4babc77242..28280c4ec7 100644 --- a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py @@ -182,7 +182,7 @@ def setup_train(self): batch_size=1, num_workers=self.config.dataloader_num_workers, collate_fn=identity_collate, - # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work ) self.iter_train_image_dataloader = iter(self.train_image_dataloader) @@ -197,9 +197,8 @@ def setup_eval(self): self.eval_image_dataloader = torch.utils.data.DataLoader( self.eval_imagebatch_stream, batch_size=1, - num_workers=self.config.dataloader_num_workers, + num_workers=0, collate_fn=identity_collate, - # pin_memory_device=self.device, ) self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) @@ -238,6 +237,6 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type) - def custom_view_processor(self, camera, image): + def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]: """An API to add latents, metadata, or other further customization an camera-and-image view dataloading process that is parallelized""" - return camera, image + return camera, data From 43c66aeefd7dbb80226342164089f2353ff1f6d2 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Sun, 29 Sep 2024 19:10:39 -0700 Subject: [PATCH 65/78] adding custom_view_processor to ImageBatchStream --- nerfstudio/data/utils/dataloaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 1d6de532e1..bc6a278297 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -602,6 +602,7 @@ def __init__( self.cache_images_type = cache_images_type self.sampling_seed = sampling_seed self.device = device + self.custom_view_processor = custom_view_processor def __iter__(self): # print(self.input_dataset.cameras.device) prints cpu @@ -619,7 +620,7 @@ def __iter__(self): r = random.Random(self.sampling_seed) r.shuffle(worker_indices) i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera - print("HELLO", worker_info.id) + while True: if i >= len( worker_indices From 1922566c4d4e7bf7a3d00e3d7df174753b179679 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 30 Sep 2024 23:11:29 -0700 Subject: [PATCH 66/78] reverting full_images_datamanager to main branch --- .../datamanagers/full_images_datamanager.py | 51 ++----------------- 1 file changed, 5 insertions(+), 46 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index eef705eb91..43ccb6d08b 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -44,8 +44,6 @@ from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.utils.data_utils import identity_collate -from nerfstudio.data.utils.dataloaders import ImageBatchStream from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -90,7 +88,7 @@ class FullImageDatamanagerConfig(DataManagerConfig): def __post_init__(self): if self.load_from_disk: - self.prefetch_factor = 2 if self.use_parallel_dataloader else None + self.prefetch_factor = 4 if self.use_parallel_dataloader else None if self.use_parallel_dataloader: try: @@ -140,7 +138,6 @@ def __init__( self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") self.train_dataset = self.create_train_dataset() self.eval_dataset = self.create_eval_dataset() - if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": CONSOLE.print( "Train dataset has over 500 images, overriding cache_images to cpu", @@ -325,45 +322,15 @@ def get_datapath(self) -> Path: def setup_train(self): """Sets up the data loaders for training""" - if self.config.use_parallel_dataloader: - self.train_imagebatch_stream = ImageBatchStream( - input_dataset=self.train_dataset, - datamanager_config=self.config, - device=self.device, - ) - self.train_image_dataloader = torch.utils.data.DataLoader( - self.train_imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work? - ) - self.iter_train_image_dataloader = iter(self.train_image_dataloader) def setup_eval(self): """Sets up the data loader for evaluation""" - if self.config.use_parallel_dataloader: - self.eval_imagebatch_stream = ImageBatchStream( - input_dataset=self.eval_dataset, - datamanager_config=self.config, - device=self.device, - ) - self.eval_image_dataloader = torch.utils.data.DataLoader( - self.eval_imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - # pin_memory_device=self.device, - ) - self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) @property def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: """ Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples """ - if self.config.use_parallel_dataloader: - return self.iter_eval_image_dataloader image_indices = [i for i in range(len(self.eval_dataset))] data = [d.copy() for d in self.cached_eval] _cameras = deepcopy(self.eval_dataset.cameras).to(self.device) @@ -394,22 +361,18 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next training batch Returns a Camera instead of raybundle""" - self.train_count += 1 - if self.config.use_parallel_dataloader: - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data - image_idx = self.train_unseen_cameras.pop(0) # Make sure to re-populate the unseen cameras list if we have exhausted it if len(self.train_unseen_cameras) == 0: self.train_unseen_cameras = self.sample_train_cameras() + data = self.cached_train[image_idx] # We're going to copy to make sure we don't mutate the cached dictionary. # This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335 data = data.copy() data["image"] = data["image"].to(self.device) - assert lDuden(self.train_cameras.shape) == 1, "Assumes single batch dimension" + assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension" camera = self.train_cameras[image_idx : image_idx + 1].to(self.device) if camera.metadata is None: camera.metadata = {} @@ -420,11 +383,6 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next evaluation batch Returns a Camera instead of raybundle""" - self.eval_count += 1 - if self.config.use_parallel_dataloader: - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data - return self.next_eval_image(step=step) def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: @@ -451,7 +409,7 @@ def _undistort_image( mask = None if camera.camera_type.item() == CameraType.PERSPECTIVE.value: assert distortion_params[3] == 0, ( - "We don't support the 4th Brown parameter for image undistortion, " + "We doesn't support the 4th Brown parameter for image undistortion, " "Only k1, k2, k3, p1, p2 can be non-zero." ) # because OpenCV expects the order of distortion parameters to be (k1, k2, p1, p2, k3), we need to reorder them @@ -620,4 +578,5 @@ def _undistort_image( K = undist_K.numpy() else: raise NotImplementedError("Only perspective and fisheye cameras are supported") + return K, image, mask From beb74beae9017bc48fb358e2c375fcaed5715753 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 30 Sep 2024 23:25:53 -0700 Subject: [PATCH 67/78] removing nn.Module inheritance from Datamanager class --- .../data/datamanagers/base_datamanager.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 330593e19d..377c666355 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -39,10 +39,9 @@ get_args, get_origin, ) -import time + import torch import tyro -from torch import nn from torch.nn import Parameter from torch.utils.data.distributed import DistributedSampler from typing_extensions import TypeVar @@ -56,14 +55,8 @@ from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig -from nerfstudio.data.utils.dataloaders import ( - # CacheDataloader, - RayBatchStream, - FixedIndicesEvalDataloader, - RandIndicesEvalDataloader, -) +from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate -from nerfstudio.data.utils.data_utils import identity_collate from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.model_components.ray_generators import RayGenerator from nerfstudio.utils.misc import IterableWrapper, get_orig_class @@ -116,7 +109,7 @@ class DataManagerConfig(InstantiateConfig): """Process images on GPU for speed at the expense of memory, if True.""" -class DataManager(nn.Module): +class DataManager: """Generic data manager's abstract class This version of the data manager is designed be a monolithic way to load data and latents, @@ -370,23 +363,27 @@ def __post_init__(self): "\nCameraOptimizerConfig has been moved from the DataManager to the Model.\n", style="bold yellow" ) warnings.warn("above message coming from", FutureWarning, stacklevel=3) - + """ These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck. """ if self.load_from_disk: - self.train_num_images_to_sample_from = 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from - self.train_num_times_to_repeat_images = 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images + self.train_num_images_to_sample_from = ( + 50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from + ) + self.train_num_times_to_repeat_images = ( + 10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images + ) self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None - + if self.use_parallel_dataloader: try: torch.multiprocessing.set_start_method("spawn") except RuntimeError: pass self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers - + TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset) From 087cff00ea4d8808793f6f57258c4f45a12b07b1 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 30 Sep 2024 23:29:50 -0700 Subject: [PATCH 68/78] don't need to move datamanger to device anymore since Datamanager is not a subclass of nn.Module --- nerfstudio/pipelines/base_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 7916c732de..7a837fdb58 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -263,7 +263,6 @@ def __init__( pts = self.datamanager.train_dataparser_outputs.metadata["points3D_xyz"] pts_rgb = self.datamanager.train_dataparser_outputs.metadata["points3D_rgb"] seed_pts = (pts, pts_rgb) - self.datamanager.to(device) # TODO(ethan): get rid of scene_bounds from the model assert self.datamanager.train_dataset is not None, "Missing input dataset" From 48e6d15afb8e31f9ad6c6077e69c30f791b03693 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 4 Oct 2024 02:10:08 -0700 Subject: [PATCH 69/78] finished integration test with nerfacto --- tests/test_nerfacto_integration.py | 89 ++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/test_nerfacto_integration.py diff --git a/tests/test_nerfacto_integration.py b/tests/test_nerfacto_integration.py new file mode 100644 index 0000000000..79ee88d242 --- /dev/null +++ b/tests/test_nerfacto_integration.py @@ -0,0 +1,89 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import subprocess +import sys +from typing import Literal + + +def run_command_with_console_output(command, stop_on_output=None): + try: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # Redirect stderr to stdout + text=True, + bufsize=1, # Line buffered + universal_newlines=True, + ) + + # Read and print output in real-time + for line in process.stdout: + print(line, end="") + sys.stdout.flush() # Ensure output is printed immediately + + if stop_on_output and stop_on_output in line: + print(f"\nDetected '{stop_on_output}'. Stopping the process.") + process.terminate() + break + + return_code = process.wait() + if return_code != 0: + print(f"Command failed with return code {return_code}") + except Exception as e: + print(f"An error occurred: {e}") + + +def run_ns_download_data(scene: Literal["poster, dozer, desolation"]): + command = f"ns-download-data nerfstudio --capture-name={scene}" + run_command_with_console_output(command) + + +def run_ns_train(scene: Literal["poster, dozer, desolation"]): + dataset_path = f"data/nerfstudio/{scene}" + command = f"ns-train nerfacto --data {dataset_path}" + run_command_with_console_output(command, stop_on_output="Checkpoint Directory") + + +def run_ns_eval(scene: Literal["poster, dozer, desolation"]): + timestamp = sorted(os.listdir(f"outputs/{scene}/nerfacto/"))[-1] + config_filename = f"outputs/{scene}/nerfacto/{timestamp}/config.yml" + command = f"ns-eval --load-config {config_filename} --output-path nerfacto_integration_eval.json" + run_command_with_console_output(command) + + with open("nerfacto_integration_eval.json", "r") as f: + results = json.load(f) + + assert results["results"]["psnr"] > 20.0 + assert results["results"]["ssim"] > 0.7 + + +def main(): + scene = "dozer" # You can change this to "poster" or "desolation" + + # print("Starting data download...") + # run_ns_download_data(scene) + + # print("\nStarting training...") + # run_ns_train(scene) + + print("\nStarting evaluation...") + run_ns_eval(scene) + + +if __name__ == "__main__": + main() From 3f1799bd883823e4c82ce0980ceca06b8ac1bac5 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 25 Oct 2024 02:35:15 -0700 Subject: [PATCH 70/78] simplified config variables, integrated the parallelism/disk-data-loading all into one datamanager --- .../datamanagers/full_images_datamanager.py | 278 ++++-------------- tests/test_nerfacto_integration.py | 12 +- 2 files changed, 70 insertions(+), 220 deletions(-) diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 43ccb6d08b..6feb91f934 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -29,7 +29,6 @@ from pathlib import Path from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin -import cv2 import fpsample import numpy as np import torch @@ -37,13 +36,14 @@ from torch.nn import Parameter from typing_extensions import assert_never -from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper -from nerfstudio.cameras.cameras import Cameras, CameraType +from nerfstudio.cameras.cameras import Cameras from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig from nerfstudio.data.datasets.base_dataset import InputDataset +from nerfstudio.data.utils.data_utils import identity_collate +from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE @@ -57,7 +57,10 @@ class FullImageDatamanagerConfig(DataManagerConfig): along with relevant information about camera intrinsics """ cache_images: Literal["cpu", "gpu", "disk"] = "gpu" - """Whether to cache images as pytorch tensors in memory. If "cpu", caches on cpu. If "gpu", caches on device. If "disk", keeps images on disk. """ + """Where to cache images in memory. + - If "cpu", caches images on cpu RAM as pytorch tensors. + - If "gpu", caches images on device as pytorch tensors. + - If "disk", keeps images on disk which conserves memory. Datamanager will use parallel dataloader""" cache_images_type: Literal["uint8", "float32"] = "float32" """The image type returned from manager, caching images in uint8 saves memory""" max_thread_workers: Optional[int] = None @@ -72,25 +75,18 @@ class FullImageDatamanagerConfig(DataManagerConfig): fps_reset_every: int = 100 """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every samples from the pool of all training cameras without replacement before a new round of sampling starts.""" - use_parallel_dataloader: bool = False - """Supports datasets that do not fit in system RAM and allows parallelization of the dataloading process with multiple workers.""" - load_from_disk: bool = False - """If True, conserves RAM memory by loading images from disk. - If False, caches all the images as tensors to RAM and loads from RAM.""" dataloader_num_workers: int = 0 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = None + prefetch_factor: int = 0 """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" cache_compressed_images: bool = False """If True, cache raw image files as byte strings to RAM.""" def __post_init__(self): - if self.load_from_disk: - self.prefetch_factor = 4 if self.use_parallel_dataloader else None - - if self.use_parallel_dataloader: + if self.cache_images == "disk": + self.prefetch_factor = 4 try: torch.multiprocessing.set_start_method("spawn") except RuntimeError: @@ -134,27 +130,21 @@ def __init__( if test_mode == "inference": self.dataparser.downscale_factor = 1 # Avoid opening images self.includes_time = self.dataparser.includes_time - self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") self.train_dataset = self.create_train_dataset() self.eval_dataset = self.create_eval_dataset() + if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": CONSOLE.print( "Train dataset has over 500 images, overriding cache_images to cpu", style="bold yellow", ) self.config.cache_images = "cpu" - self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device - if self.config.masks_on_gpu is True: - self.exclude_batch_keys_from_device.remove("mask") - if self.config.images_on_gpu is True: - self.exclude_batch_keys_from_device.remove("image") # Some logic to make sure we sample every camera in equal amounts self.train_unseen_cameras = self.sample_train_cameras() self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" - super().__init__() def sample_train_cameras(self): @@ -186,7 +176,6 @@ def sample_train_cameras(self): ) n = num_train_cameras kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3) - self.train_unsampled_epoch_count += 1 self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0 return kdline_fps_samples_idx.tolist() @@ -197,19 +186,20 @@ def sample_train_cameras(self): def cached_train(self) -> List[Dict[str, torch.Tensor]]: """Get the training images. Will load and undistort the images the first time this (cached) property is accessed.""" + assert self.config.cache_images != "disk", "Can not call _load_images() with `disk` as input" return self._load_images("train", cache_images_device=self.config.cache_images) @cached_property def cached_eval(self) -> List[Dict[str, torch.Tensor]]: """Get the eval images. Will load and undistort the images the first time this (cached) property is accessed.""" + assert self.config.cache_images != "disk", "Can not call _load_images() with `disk` as input" return self._load_images("eval", cache_images_device=self.config.cache_images) def _load_images( self, split: Literal["train", "eval"], cache_images_device: Literal["cpu", "gpu"] ) -> List[Dict[str, torch.Tensor]]: undistorted_images: List[Dict[str, torch.Tensor]] = [] - # Which dataset? if split == "train": dataset = self.train_dataset @@ -257,7 +247,6 @@ def undistort_idx(idx: int) -> Dict[str, torch.Tensor]: total=len(dataset), ) ) - # Move to device. if cache_images_device == "gpu": for cache in undistorted_images: @@ -275,7 +264,6 @@ def undistort_idx(idx: int) -> Dict[str, torch.Tensor]: self.train_cameras = self.train_dataset.cameras else: assert_never(cache_images_device) - return undistorted_images def create_train_dataset(self) -> TDataset: @@ -283,6 +271,7 @@ def create_train_dataset(self) -> TDataset: return self.dataset_type( dataparser_outputs=self.train_dataparser_outputs, scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, ) def create_eval_dataset(self) -> TDataset: @@ -290,6 +279,7 @@ def create_eval_dataset(self) -> TDataset: return self.dataset_type( dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), scale_factor=self.config.camera_res_scale_factor, + cache_compressed_images=self.config.cache_compressed_images, ) @cached_property @@ -301,7 +291,6 @@ def dataset_type(self) -> Type[TDataset]: return default if orig_class is not None and get_origin(orig_class) is FullImageDatamanager: return get_args(orig_class)[0] - # For inherited classes, we need to find the correct type to instantiate for base in getattr(self, "__orig_bases__", []): if get_origin(base) is FullImageDatamanager: @@ -322,15 +311,48 @@ def get_datapath(self) -> Path: def setup_train(self): """Sets up the data loaders for training""" + if self.config.cache_images == "disk": + self.train_imagebatch_stream = ImageBatchStream( + input_dataset=self.train_dataset, + cache_images_type=self.config.cache_images_type, + sampling_seed=self.config.train_cameras_sampling_seed, + device=self.device, + custom_view_processor=self.custom_view_processor, + ) + self.train_image_dataloader = torch.utils.data.DataLoader( + self.train_imagebatch_stream, + batch_size=1, + num_workers=self.config.dataloader_num_workers, + collate_fn=identity_collate, + # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work + ) + self.iter_train_image_dataloader = iter(self.train_image_dataloader) def setup_eval(self): """Sets up the data loader for evaluation""" + if self.config.cache_images == "disk": + self.eval_imagebatch_stream = ImageBatchStream( + input_dataset=self.eval_dataset, + cache_images_type=self.config.cache_images_type, + sampling_seed=self.config.train_cameras_sampling_seed, + device=self.device, + custom_view_processor=self.custom_view_processor, + ) + self.eval_image_dataloader = torch.utils.data.DataLoader( + self.eval_imagebatch_stream, + batch_size=1, + num_workers=0, + collate_fn=identity_collate, + ) + self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) @property def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: """ Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples """ + if self.config.cache_images == "disk": + return self.iter_eval_image_dataloader image_indices = [i for i in range(len(self.eval_dataset))] data = [d.copy() for d in self.cached_eval] _cameras = deepcopy(self.eval_dataset.cameras).to(self.device) @@ -350,17 +372,17 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: def get_train_rays_per_batch(self): """Returns resolution of the image returned from datamanager.""" - if len(self.cached_train) != 0: - h = self.cached_train[0]["image"].shape[0] - w = self.cached_train[0]["image"].shape[1] - return h * w - else: - return 800 * 800 + camera = self.train_dataset.cameras[0].reshape(()) + return camera.width.item() * camera.height.item() def next_train(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next training batch - Returns a Camera instead of raybundle""" + self.train_count += 1 + if self.config.cache_images == "disk": + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + image_idx = self.train_unseen_cameras.pop(0) # Make sure to re-populate the unseen cameras list if we have exhausted it if len(self.train_unseen_cameras) == 0: @@ -372,6 +394,7 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: data = data.copy() data["image"] = data["image"].to(self.device) + assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension" assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension" camera = self.train_cameras[image_idx : image_idx + 1].to(self.device) if camera.metadata is None: @@ -381,15 +404,17 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: def next_eval(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next evaluation batch - Returns a Camera instead of raybundle""" + self.eval_count += 1 + if self.config.cache_images == "disk": + camera, data = next(self.iter_train_image_dataloader)[0] + return camera, data + return self.next_eval_image(step=step) def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next evaluation batch - Returns a Camera instead of raybundle - TODO: Make sure this logic is consistent with the vanilladatamanager""" image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1)) # Make sure to re-populate the unseen cameras list if we have exhausted it @@ -402,181 +427,6 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) return camera, data - -def _undistort_image( - camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray -) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]: - mask = None - if camera.camera_type.item() == CameraType.PERSPECTIVE.value: - assert distortion_params[3] == 0, ( - "We doesn't support the 4th Brown parameter for image undistortion, " - "Only k1, k2, k3, p1, p2 can be non-zero." - ) - # because OpenCV expects the order of distortion parameters to be (k1, k2, p1, p2, k3), we need to reorder them - # see https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html - distortion_params = np.array( - [ - distortion_params[0], - distortion_params[1], - distortion_params[4], - distortion_params[5], - distortion_params[2], - distortion_params[3], - 0, - 0, - ] - ) - # because OpenCV expects the pixel coord to be top-left, we need to shift the principal point by 0.5 - # see https://github.com/nerfstudio-project/nerfstudio/issues/3048 - K[0, 2] = K[0, 2] - 0.5 - K[1, 2] = K[1, 2] - 0.5 - if np.any(distortion_params): - newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0) - image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore - else: - newK = K - roi = 0, 0, image.shape[1], image.shape[0] - # crop the image and update the intrinsics accordingly - x, y, w, h = roi - image = image[y : y + h, x : x + w] - # update the principal point based on our cropped region of interest (ROI) - newK[0, 2] -= x - newK[1, 2] -= y - if "depth_image" in data: - data["depth_image"] = data["depth_image"][y : y + h, x : x + w] - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - if np.any(distortion_params): - mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore - mask = mask[y : y + h, x : x + w] - mask = torch.from_numpy(mask).bool() - if len(mask.shape) == 2: - mask = mask[:, :, None] - newK[0, 2] = newK[0, 2] + 0.5 - newK[1, 2] = newK[1, 2] + 0.5 - K = newK - - elif camera.camera_type.item() == CameraType.FISHEYE.value: - K[0, 2] = K[0, 2] - 0.5 - K[1, 2] = K[1, 2] - 0.5 - distortion_params = np.array( - [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]] - ) - newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify( - K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0 - ) - map1, map2 = cv2.fisheye.initUndistortRectifyMap( - K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1 - ) - # and then remap: - image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) - if "mask" in data: - mask = data["mask"].numpy() - mask = mask.astype(np.uint8) * 255 - mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK) - mask = torch.from_numpy(mask).bool() - if len(mask.shape) == 2: - mask = mask[:, :, None] - newK[0, 2] = newK[0, 2] + 0.5 - newK[1, 2] = newK[1, 2] + 0.5 - K = newK - elif camera.camera_type.item() == CameraType.FISHEYE624.value: - fisheye624_params = torch.cat( - [camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0 - ) - assert fisheye624_params.shape == (16,) - assert ( - "mask" not in data - and camera.metadata is not None - and "fisheye_crop_radius" in camera.metadata - and isinstance(camera.metadata["fisheye_crop_radius"], float) - ) - fisheye_crop_radius = camera.metadata["fisheye_crop_radius"] - - # Approximate the FOV of the unmasked region of the camera. - upper, lower, left, right = fisheye624_unproject_helper( - torch.tensor( - [ - [camera.cx, camera.cy - fisheye_crop_radius], - [camera.cx, camera.cy + fisheye_crop_radius], - [camera.cx - fisheye_crop_radius, camera.cy], - [camera.cx + fisheye_crop_radius, camera.cy], - ], - dtype=torch.float32, - )[None], - params=fisheye624_params[None], - ).squeeze(dim=0) - fov_radians = torch.max( - torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))), - torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))), - ) - - # Heuristics to determine parameters of an undistorted image. - undist_h = int(fisheye_crop_radius * 2) - undist_w = int(fisheye_crop_radius * 2) - undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0)) - undist_K = torch.eye(3) - undist_K[0, 0] = undistort_focal # fx - undist_K[1, 1] = undistort_focal # fy - undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0). - undist_K[1, 2] = (undist_h - 1) / 2.0 # cy - - # Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates. - undist_uv_homog = torch.stack( - [ - *torch.meshgrid( - torch.arange(undist_w, dtype=torch.float32), - torch.arange(undist_h, dtype=torch.float32), - ), - torch.ones((undist_w, undist_h), dtype=torch.float32), - ], - dim=-1, - ) - assert undist_uv_homog.shape == (undist_w, undist_h, 3) - dist_uv = ( - fisheye624_project( - xyz=( - torch.einsum( - "ij,bj->bi", - torch.linalg.inv(undist_K), - undist_uv_homog.reshape((undist_w * undist_h, 3)), - )[None] - ), - params=fisheye624_params[None, :], - ) - .reshape((undist_w, undist_h, 2)) - .numpy() - ) - map1 = dist_uv[..., 1] - map2 = dist_uv[..., 0] - - # Use correspondence to undistort image. - image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR) - - # Compute undistorted mask as well. - dist_h = camera.height.item() - dist_w = camera.width.item() - mask = np.mgrid[:dist_h, :dist_w] - mask[0, ...] -= dist_h // 2 - mask[1, ...] -= dist_w // 2 - mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius - mask = torch.from_numpy( - cv2.remap( - mask.astype(np.uint8) * 255, - map1, - map2, - interpolation=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=0, - ) - / 255.0 - ).bool()[..., None] - if len(mask.shape) == 2: - mask = mask[:, :, None] - assert mask.shape == (undist_h, undist_w, 1) - K = undist_K.numpy() - else: - raise NotImplementedError("Only perspective and fisheye cameras are supported") - - return K, image, mask + def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]: + """An API to add latents, metadata, or other further customization an camera-and-image view dataloading process that is parallelized""" + return camera, data diff --git a/tests/test_nerfacto_integration.py b/tests/test_nerfacto_integration.py index 79ee88d242..cd7786e35e 100644 --- a/tests/test_nerfacto_integration.py +++ b/tests/test_nerfacto_integration.py @@ -68,18 +68,18 @@ def run_ns_eval(scene: Literal["poster, dozer, desolation"]): with open("nerfacto_integration_eval.json", "r") as f: results = json.load(f) - assert results["results"]["psnr"] > 20.0 - assert results["results"]["ssim"] > 0.7 + assert results["results"]["psnr"] > 20.0, "PSNR was lower than 20" + assert results["results"]["ssim"] > 0.7, "SSIM was lower than 0.7" def main(): scene = "dozer" # You can change this to "poster" or "desolation" - # print("Starting data download...") - # run_ns_download_data(scene) + print("Starting data download...") + run_ns_download_data(scene) - # print("\nStarting training...") - # run_ns_train(scene) + print("\nStarting training...") + run_ns_train(scene) print("\nStarting evaluation...") run_ns_eval(scene) From f46aa4277ae59aec131b660080e7ad03843734c4 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 25 Oct 2024 02:36:54 -0700 Subject: [PATCH 71/78] updated the splatfacto config to be simpler with the dataloading and now uses FullImageDatamanager (which has been changed) --- nerfstudio/configs/method_configs.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index b544883c15..76cd7bce03 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -134,9 +134,12 @@ mixed_precision=True, pipeline=VanillaPipelineConfig( datamanager=VanillaDataManagerConfig( + _target=ParallelDataManager[InputDataset], dataparser=NerfstudioDataParserConfig(), train_num_rays_per_batch=8192, eval_num_rays_per_batch=4096, + load_from_disk=True, + use_parallel_dataloader=True, ), model=NerfactoModelConfig( eval_num_rays_per_chunk=1 << 15, @@ -606,11 +609,9 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( - _target=ParallelFullImageDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", - use_parallel_dataloader=True, cache_images="disk", ), model=SplatfactoModelConfig(), @@ -667,8 +668,11 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( - dataparser=NerfstudioDataParserConfig(load_3D_points=True), + _target=ParallelFullImageDatamanager[InputDataset], + dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), + # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", + cache_images="disk", ), model=SplatfactoModelConfig( cull_alpha_thresh=0.005, From 5aa51fba37fc718967de78b4e91aefc2e9839154 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 25 Oct 2024 02:39:14 -0700 Subject: [PATCH 72/78] style checks and some cleanup --- .../data/datamanagers/parallel_datamanager.py | 3 +-- .../parallel_full_images_datamanager.py | 5 ----- nerfstudio/data/utils/dataloaders.py | 13 ++++++------- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index d97e319639..031fde5d9c 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -179,12 +179,11 @@ def setup_eval(self): collate_fn=identity_collate, # Our dataset handles batching / collation of rays ) self.iter_eval_raybundles = iter(self.eval_ray_dataloader) - self.image_eval_dataloader = RandIndicesEvalDataloader( + self.image_eval_dataloader = RandIndicesEvalDataloader( # this is used for ns-eval input_dataset=self.eval_dataset, device=self.device, num_workers=self.world_size * 4, ) - # this is used for ns-eval self.fixed_indices_eval_dataloader = FixedIndicesEvalDataloader( input_dataset=self.eval_dataset, device=self.device, diff --git a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py index 28280c4ec7..94ec321c9b 100644 --- a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py @@ -76,11 +76,6 @@ def __init__( style="bold yellow", ) self.config.cache_images = "cpu" - self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device - if self.config.masks_on_gpu is True: - self.exclude_batch_keys_from_device.remove("mask") - if self.config.images_on_gpu is True: - self.exclude_batch_keys_from_device.remove("image") # Some logic to make sure we sample every camera in equal amounts self.train_unseen_cameras = self.sample_train_cameras() diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index bc6a278297..dd2c81e730 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -399,7 +399,6 @@ def __iter__(self): import math from torch.utils.data import Dataset -from tqdm.auto import tqdm class RayBatchStream(torch.utils.data.IterableDataset): @@ -483,8 +482,9 @@ def _get_pixel_sampler(self, dataset: Dataset, num_rays_per_batch: int) -> Pixel def _get_batch_list(self, indices=None): """Returns a list representing a single batch from the dataset attribute. Each item of the list is a dictionary with dict_keys(['image_idx', 'image']) representing 1 image. - This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch - The length of the list is equal to the (# of training images) / (num_workers)""" + This function is used to sample and load images from disk/RAM and is only called in _get_collated_batch() + The length of the list is equal to the (# of training images) / (num_workers) + """ assert isinstance(self.input_dataset, Sized) if indices is None: @@ -504,7 +504,7 @@ def _get_batch_list(self, indices=None): for idx in indices: res = executor.submit(self.input_dataset.__getitem__, idx) results.append(res) - results = tqdm(results) # this is temporary and will be removed in the final push + # results = tqdm(results) # this is temporary and will be removed in the final push for res in results: batch_list.append(res.result()) @@ -622,9 +622,8 @@ def __iter__(self): i = 0 # i refers to what image index we are outputting: i=0 => we are yielding our first image,camera while True: - if i >= len( - worker_indices - ): # if we've iterated through all the worker's partition of images, we need to reshuffle + if i >= len(worker_indices): + # if we've iterated through all the worker's partition of images, we need to reshuffle r.shuffle(worker_indices) i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve From ec3c12a5d4ccbecd7ae8171401b65c8f24d50502 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 25 Oct 2024 02:44:01 -0700 Subject: [PATCH 73/78] new splatfacto test, cleaning up nerfacto integration test --- tests/test_nerfacto_integration.py | 8 ++-- tests/test_splatfacto_integration.py | 55 ++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) create mode 100644 tests/test_splatfacto_integration.py diff --git a/tests/test_nerfacto_integration.py b/tests/test_nerfacto_integration.py index cd7786e35e..64b2e528fc 100644 --- a/tests/test_nerfacto_integration.py +++ b/tests/test_nerfacto_integration.py @@ -48,18 +48,18 @@ def run_command_with_console_output(command, stop_on_output=None): print(f"An error occurred: {e}") -def run_ns_download_data(scene: Literal["poster, dozer, desolation"]): +def run_ns_download_data(scene: Literal["poster", "dozer", "desolation"]): command = f"ns-download-data nerfstudio --capture-name={scene}" run_command_with_console_output(command) -def run_ns_train(scene: Literal["poster, dozer, desolation"]): +def run_ns_train_nerfacto(scene: Literal["poster", "dozer", "desolation"]): dataset_path = f"data/nerfstudio/{scene}" command = f"ns-train nerfacto --data {dataset_path}" run_command_with_console_output(command, stop_on_output="Checkpoint Directory") -def run_ns_eval(scene: Literal["poster, dozer, desolation"]): +def run_ns_eval(scene: Literal["poster", "dozer", "desolation"]): timestamp = sorted(os.listdir(f"outputs/{scene}/nerfacto/"))[-1] config_filename = f"outputs/{scene}/nerfacto/{timestamp}/config.yml" command = f"ns-eval --load-config {config_filename} --output-path nerfacto_integration_eval.json" @@ -79,7 +79,7 @@ def main(): run_ns_download_data(scene) print("\nStarting training...") - run_ns_train(scene) + run_ns_train_nerfacto(scene) print("\nStarting evaluation...") run_ns_eval(scene) diff --git a/tests/test_splatfacto_integration.py b/tests/test_splatfacto_integration.py new file mode 100644 index 0000000000..c798f914fc --- /dev/null +++ b/tests/test_splatfacto_integration.py @@ -0,0 +1,55 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Literal + +from test_nerfacto_integration import run_command_with_console_output, run_ns_download_data + + +def run_ns_train_splatfacto(scene: Literal["poster", "dozer", "desolation"]): + dataset_path = f"data/nerfstudio/{scene}" + command = f"ns-train splatfacto --data {dataset_path}" + run_command_with_console_output(command, stop_on_output="Checkpoint Directory") + + +def run_ns_eval(scene: Literal["poster", "dozer", "desolation"]): + timestamp = sorted(os.listdir(f"outputs/{scene}/splatfacto/"))[-1] + config_filename = f"outputs/{scene}/splatfacto/{timestamp}/config.yml" + command = f"ns-eval --load-config {config_filename} --output-path splatfacto_integration_eval.json" + run_command_with_console_output(command) + + with open("splatfacto_integration_eval.json", "r") as f: + results = json.load(f) + + assert results["results"]["psnr"] > 20.0, "PSNR was lower than 20" + assert results["results"]["ssim"] > 0.7, "SSIM was lower than 0.7" + + +def main(): + scene = "dozer" # You can change this to "poster" or "desolation" + + print("Starting data download...") + run_ns_download_data(scene) + + print("\nStarting training...") + run_ns_train_splatfacto(scene) + + print("\nStarting evaluation...") + run_ns_eval(scene) + + +if __name__ == "__main__": + main() From 82bc5b2a0cac39846462c3215d152d17231719ad Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 25 Oct 2024 23:56:32 -0700 Subject: [PATCH 74/78] removing redundant parallel_full_images_datamaanger, as the OG full_image_datamanager now has full parallelized support --- .../parallel_full_images_datamanager.py | 237 ------------------ 1 file changed, 237 deletions(-) delete mode 100644 nerfstudio/data/datamanagers/parallel_full_images_datamanager.py diff --git a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py b/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py deleted file mode 100644 index 94ec321c9b..0000000000 --- a/nerfstudio/data/datamanagers/parallel_full_images_datamanager.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Parallel data manager that outputs cameras / images instead of raybundles. -""" - -from __future__ import annotations - -import random -from functools import cached_property -from pathlib import Path -from typing import Dict, ForwardRef, Generic, List, Literal, Tuple, Type, Union, cast, get_args, get_origin - -import fpsample -import numpy as np -import torch -from torch.nn import Parameter - -from nerfstudio.cameras.cameras import Cameras -from nerfstudio.data.datamanagers.base_datamanager import DataManager, TDataset -from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig -from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs -from nerfstudio.data.datasets.base_dataset import InputDataset -from nerfstudio.data.utils.data_utils import identity_collate -from nerfstudio.data.utils.dataloaders import ImageBatchStream, undistort_view -from nerfstudio.utils.misc import get_orig_class -from nerfstudio.utils.rich_utils import CONSOLE - - -class ParallelFullImageDatamanager(DataManager, Generic[TDataset]): - def __init__( - self, - config: FullImageDatamanagerConfig, - device: Union[torch.device, str] = "cpu", - test_mode: Literal["test", "val", "inference"] = "val", - world_size: int = 1, - local_rank: int = 0, - **kwargs, - ): - self.config = config - self.device = device - self.world_size = world_size - self.local_rank = local_rank - self.sampler = None - self.test_mode = test_mode - self.test_split = "test" if test_mode in ["test", "inference"] else "val" - self.dataparser_config = self.config.dataparser - if self.config.data is not None: - self.config.dataparser.data = Path(self.config.data) - else: - self.config.data = self.config.dataparser.data - self.dataparser = self.dataparser_config.setup() - if test_mode == "inference": - self.dataparser.downscale_factor = 1 # Avoid opening images - self.includes_time = self.dataparser.includes_time - - self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train") - self.train_dataset = self.create_train_dataset() - self.eval_dataset = self.create_eval_dataset() - - if len(self.train_dataset) > 500 and self.config.cache_images == "gpu": - CONSOLE.print( - "Train dataset has over 500 images, overriding cache_images to cpu", - style="bold yellow", - ) - self.config.cache_images = "cpu" - - # Some logic to make sure we sample every camera in equal amounts - self.train_unseen_cameras = self.sample_train_cameras() - self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] - assert len(self.train_unseen_cameras) > 0, "No data found in dataset" - - super().__init__() - - def sample_train_cameras(self): - """Return a list of camera indices sampled using the strategy specified by - self.config.train_cameras_sampling_strategy""" - num_train_cameras = len(self.train_dataset) - if self.config.train_cameras_sampling_strategy == "random": - if not hasattr(self, "random_generator"): - self.random_generator = random.Random(self.config.train_cameras_sampling_seed) - indices = list(range(num_train_cameras)) - self.random_generator.shuffle(indices) - return indices - elif self.config.train_cameras_sampling_strategy == "fps": - if not hasattr(self, "train_unsampled_epoch_count"): - np.random.seed(self.config.train_cameras_sampling_seed) # fix random seed of fpsample - self.train_unsampled_epoch_count = np.zeros(num_train_cameras) - camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy() - # We concatenate camera origins with weighted train_unsampled_epoch_count because we want to - # increase the chance to sample camera that hasn't been sampled in consecutive epochs previously. - # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of scene - data = np.concatenate( - (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1 - ) - n = self.config.fps_reset_every - if num_train_cameras < n: - CONSOLE.log( - f"num_train_cameras={num_train_cameras} is smaller than fps_reset_ever={n}, the behavior of " - "camera sampler will be very similar to sampling random without replacement (default setting)." - ) - n = num_train_cameras - kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3) - - self.train_unsampled_epoch_count += 1 - self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0 - return kdline_fps_samples_idx.tolist() - else: - raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}") - - def create_train_dataset(self) -> TDataset: - """Sets up the data loaders for training""" - return self.dataset_type( - dataparser_outputs=self.train_dataparser_outputs, - scale_factor=self.config.camera_res_scale_factor, - cache_compressed_images=self.config.cache_compressed_images, - ) - - def create_eval_dataset(self) -> TDataset: - """Sets up the data loaders for evaluation""" - return self.dataset_type( - dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split), - scale_factor=self.config.camera_res_scale_factor, - cache_compressed_images=self.config.cache_compressed_images, - ) - - @cached_property - def dataset_type(self) -> Type[TDataset]: - """Returns the dataset type passed as the generic argument""" - default: Type[TDataset] = cast(TDataset, TDataset.__default__) # type: ignore - orig_class: Type[ParallelFullImageDatamanager] = get_orig_class(self, default=None) # type: ignore - if type(self) is ParallelFullImageDatamanager and orig_class is None: - return default - if orig_class is not None and get_origin(orig_class) is ParallelFullImageDatamanager: - return get_args(orig_class)[0] - - # For inherited classes, we need to find the correct type to instantiate - for base in getattr(self, "__orig_bases__", []): - if get_origin(base) is ParallelFullImageDatamanager: - for value in get_args(base): - if isinstance(value, ForwardRef): - if value.__forward_evaluated__: - value = value.__forward_value__ - elif value.__forward_module__ is None: - value.__forward_module__ = type(self).__module__ - value = getattr(value, "_evaluate")(None, None, set()) - assert isinstance(value, type) - if issubclass(value, InputDataset): - return cast(Type[TDataset], value) - return default - - def get_datapath(self) -> Path: - return self.config.dataparser.data - - def setup_train(self): - self.train_imagebatch_stream = ImageBatchStream( - input_dataset=self.train_dataset, - cache_images_type=self.config.cache_images_type, - sampling_seed=self.config.train_cameras_sampling_seed, - device=self.device, - custom_view_processor=self.custom_view_processor, - ) - self.train_image_dataloader = torch.utils.data.DataLoader( - self.train_imagebatch_stream, - batch_size=1, - num_workers=self.config.dataloader_num_workers, - collate_fn=identity_collate, - # pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work - ) - self.iter_train_image_dataloader = iter(self.train_image_dataloader) - - def setup_eval(self): - self.eval_imagebatch_stream = ImageBatchStream( - input_dataset=self.eval_dataset, - cache_images_type=self.config.cache_images_type, - sampling_seed=self.config.train_cameras_sampling_seed, - device=self.device, - custom_view_processor=self.custom_view_processor, - ) - self.eval_image_dataloader = torch.utils.data.DataLoader( - self.eval_imagebatch_stream, - batch_size=1, - num_workers=0, - collate_fn=identity_collate, - ) - self.iter_eval_image_dataloader = iter(self.eval_image_dataloader) - - @property - def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: - return self.iter_eval_image_dataloader - - def get_param_groups(self) -> Dict[str, List[Parameter]]: - """Get the param groups for the data manager. - Returns: - A list of dictionaries containing the data manager's param groups. - """ - return {} - - def get_train_rays_per_batch(self): - # TODO: fix this to be the resolution of the last image rendered - return 800 * 800 - - def next_train(self, step: int) -> Tuple[Cameras, Dict]: - self.train_count += 1 - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data - - def next_eval(self, step: int) -> Tuple[Cameras, Dict]: - self.eval_count += 1 - camera, data = next(self.iter_train_image_dataloader)[0] - return camera, data - - def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: - """Returns the next evaluation batch - - Returns a Camera instead of raybundle""" - image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1)) - # Make sure to re-populate the unseen cameras list if we have exhausted it - if len(self.eval_unseen_cameras) == 0: - self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] - return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type) - - def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]: - """An API to add latents, metadata, or other further customization an camera-and-image view dataloading process that is parallelized""" - return camera, data From bbb5473c1ee4c91a9b392f99892b9bf5acf5f5d6 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 28 Oct 2024 03:59:37 -0700 Subject: [PATCH 75/78] ruff linting and pyright fixing --- .../data/datamanagers/base_datamanager.py | 1 + .../datamanagers/full_images_datamanager.py | 5 +++-- .../data/datamanagers/parallel_datamanager.py | 5 +++-- nerfstudio/data/datasets/base_dataset.py | 19 ++++++++++++------- nerfstudio/data/utils/data_utils.py | 13 ++++++++----- nerfstudio/data/utils/dataloaders.py | 13 ++++--------- 6 files changed, 31 insertions(+), 25 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 377c666355..b968fc25bf 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -162,6 +162,7 @@ class DataManager: train_sampler: Optional[DistributedSampler] = None eval_sampler: Optional[DistributedSampler] = None includes_time: bool = False + test_mode: Literal["test", "val", "inference"] = "val" def __init__(self): """Constructor for the DataManager class. diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 6feb91f934..e44d616977 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -34,6 +34,7 @@ import torch from rich.progress import track from torch.nn import Parameter +from torch.utils.data import DataLoader from typing_extensions import assert_never from nerfstudio.cameras.cameras import Cameras @@ -319,7 +320,7 @@ def setup_train(self): device=self.device, custom_view_processor=self.custom_view_processor, ) - self.train_image_dataloader = torch.utils.data.DataLoader( + self.train_image_dataloader = DataLoader( self.train_imagebatch_stream, batch_size=1, num_workers=self.config.dataloader_num_workers, @@ -338,7 +339,7 @@ def setup_eval(self): device=self.device, custom_view_processor=self.custom_view_processor, ) - self.eval_image_dataloader = torch.utils.data.DataLoader( + self.eval_image_dataloader = DataLoader( self.eval_imagebatch_stream, batch_size=1, num_workers=0, diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index 031fde5d9c..d78ebc3855 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -24,6 +24,7 @@ import torch from torch.nn import Parameter +from torch.utils.data import DataLoader from nerfstudio.cameras.cameras import Cameras from nerfstudio.cameras.rays import RayBundle @@ -150,7 +151,7 @@ def setup_train(self): load_from_disk=self.config.load_from_disk, custom_ray_processor=self.custom_ray_processor, ) - self.train_ray_dataloader = torch.utils.data.DataLoader( + self.train_ray_dataloader = DataLoader( self.train_raybatchstream, batch_size=1, num_workers=self.config.dataloader_num_workers, @@ -171,7 +172,7 @@ def setup_eval(self): load_from_disk=True, custom_ray_processor=self.custom_ray_processor, ) - self.eval_ray_dataloader = torch.utils.data.DataLoader( + self.eval_ray_dataloader = DataLoader( self.eval_raybatchstream, batch_size=1, num_workers=0, diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index 19d977f890..279988f7f2 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -18,8 +18,8 @@ from __future__ import annotations -from copy import deepcopy import io +from copy import deepcopy from pathlib import Path from typing import Dict, List, Literal @@ -29,12 +29,13 @@ from jaxtyping import Float, UInt8 from PIL import Image from torch import Tensor +from torch.profiler import record_function from torch.utils.data import Dataset from nerfstudio.cameras.cameras import Cameras from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs from nerfstudio.data.utils.data_utils import get_image_mask_tensor_from_path, pil_to_numpy -from torch.profiler import record_function + class InputDataset(Dataset): """Dataset that returns images. @@ -47,7 +48,9 @@ class InputDataset(Dataset): exclude_batch_keys_from_device: List[str] = ["image", "mask"] cameras: Cameras - def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_compressed_images: bool = False): + def __init__( + self, dataparser_outputs: DataparserOutputs, scale_factor: float = 1.0, cache_compressed_images: bool = False + ): super().__init__() self._dataparser_outputs = dataparser_outputs self.scale_factor = scale_factor @@ -62,11 +65,11 @@ def __init__(self, dataparser_outputs: DataparserOutputs, scale_factor: float = self.binary_images = [] self.binary_masks = [] for image_filename in self._dataparser_outputs.image_filenames: - with open(image_filename, 'rb') as f: + with open(image_filename, "rb") as f: self.binary_images.append(io.BytesIO(f.read())) if self._dataparser_outputs.mask_filenames is not None: for mask_filename in self._dataparser_outputs.mask_filenames: - with open(mask_filename, 'rb') as f: + with open(mask_filename, "rb") as f: self.binary_masks.append(io.BytesIO(f.read())) def __len__(self): @@ -87,7 +90,7 @@ def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: width, height = pil_image.size newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) pil_image = pil_image.resize(newsize, resample=Image.Resampling.BILINEAR) - image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" + image = pil_to_numpy(pil_image) # # shape is (h, w) or (h, w, 3 or 4) and dtype == "uint8" if len(image.shape) == 2: image = image[:, :, None].repeat(3, axis=2) assert len(image.shape) == 3 @@ -120,7 +123,9 @@ def get_image_uint8(self, image_idx: int) -> UInt8[Tensor, "image_height image_w Args: image_idx: The image index in the dataset. """ - image = torch.from_numpy(self.get_numpy_image(image_idx)) # removed astype(np.uint8) because get_numpy_image returns uint8 + image = torch.from_numpy( + self.get_numpy_image(image_idx) + ) # removed astype(np.uint8) because get_numpy_image returns uint8 if self._dataparser_outputs.alpha_color is not None and image.shape[-1] == 4: assert (self._dataparser_outputs.alpha_color >= 0).all() and ( self._dataparser_outputs.alpha_color <= 1 diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index b1eb3b867d..97f7a53f0f 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -15,15 +15,17 @@ """Utility functions to allow easy re-use of common operations across dataloaders""" from pathlib import Path -from typing import List, Tuple, Union, IO +from typing import IO, List, Tuple, Union import cv2 import numpy as np import torch from PIL import Image +from PIL.Image import Image as PILImage from torch.profiler import record_function -def pil_to_numpy(im: Image) -> np.ndarray: + +def pil_to_numpy(im: PILImage) -> np.ndarray: """Converts a PIL Image object to a NumPy array. Args: @@ -47,8 +49,8 @@ def pil_to_numpy(im: Image) -> np.ndarray: bufsize, s, offset = 65536, 0, 0 while not s: - l, s, d = e.encode(bufsize) - mem[offset:offset + len(d)] = d + _, s, d = e.encode(bufsize) + mem[offset : offset + len(d)] = d offset += len(d) if s < 0: raise RuntimeError("encoder error %d in tobytes" % s) @@ -119,6 +121,7 @@ def get_depth_image_from_path( image = cv2.resize(image, (width, height), interpolation=interpolation) return torch.from_numpy(image[:, :, np.newaxis]) + def identity_collate(x): """This function does nothing but serves to help our dataloaders have a pickleable function, as lambdas are not pickleable""" - return x \ No newline at end of file + return x diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index dd2c81e730..10faf3bd56 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -18,6 +18,7 @@ # for multithreading import concurrent.futures +import math import multiprocessing import random from abc import abstractmethod @@ -28,7 +29,7 @@ import numpy as np import torch from rich.progress import track -from torch.utils.data import Dataset +from torch.utils.data import Dataset, get_worker_info from torch.utils.data.dataloader import DataLoader from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper @@ -396,11 +397,6 @@ def __iter__(self): yield collated_batch -import math - -from torch.utils.data import Dataset - - class RayBatchStream(torch.utils.data.IterableDataset): """Wrapper around Pytorch's IterableDataset to generate the next batch of rays (next RayBundle) and corresponding labels with multiple parallel workers. @@ -527,7 +523,7 @@ def _get_collated_batch(self, indices=None): def __iter__(self): """This implementation allows every worker only cache the indices of the images they will use to generate rays to conserve RAM memory.""" - worker_info = torch.utils.data.get_worker_info() + worker_info = get_worker_info() if worker_info is not None: # if we have multiple processes per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers))) slice_start = worker_info.id * per_worker @@ -605,9 +601,8 @@ def __init__( self.custom_view_processor = custom_view_processor def __iter__(self): - # print(self.input_dataset.cameras.device) prints cpu dataset_indices = list(range(len(self.input_dataset))) - worker_info = torch.utils.data.get_worker_info() + worker_info = get_worker_info() if worker_info is not None: # if we have multiple processes per_worker = int(math.ceil(len(dataset_indices) / float(worker_info.num_workers))) slice_start = worker_info.id * per_worker From 2e64120f9fe286155e11148b3ee2a99cecd9f28d Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 28 Oct 2024 04:53:26 -0700 Subject: [PATCH 76/78] further pyright fixing --- nerfstudio/data/datamanagers/base_datamanager.py | 2 +- .../data/datamanagers/full_images_datamanager.py | 15 +++++++++++---- nerfstudio/data/utils/dataloaders.py | 15 +++++++-------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index b968fc25bf..0df4a97c44 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -343,7 +343,7 @@ class VanillaDataManagerConfig(DataManagerConfig): dataloader_num_workers: int = 0 """The number of workers performing the dataloading from either disk/RAM, which includes collating, pixel sampling, unprojecting, ray generation etc.""" - prefetch_factor: int = None + prefetch_factor: int | None = None """The limit number of batches a worker will start loading once an iterator is created. More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader""" cache_compressed_images: bool = False diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index e44d616977..4fe7837965 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -353,7 +353,14 @@ def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]: Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples """ if self.config.cache_images == "disk": - return self.iter_eval_image_dataloader + dataloader = DataLoader( + self.eval_imagebatch_stream, + batch_size=1, + num_workers=0, + collate_fn=identity_collate, + ) + return [batch[0] for batch in dataloader] + image_indices = [i for i in range(len(self.eval_dataset))] data = [d.copy() for d in self.cached_eval] _cameras = deepcopy(self.eval_dataset.cameras).to(self.device) @@ -371,10 +378,10 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: """ return {} - def get_train_rays_per_batch(self): + def get_train_rays_per_batch(self) -> int: """Returns resolution of the image returned from datamanager.""" camera = self.train_dataset.cameras[0].reshape(()) - return camera.width.item() * camera.height.item() + return int(camera.width[0].item() * camera.height[0].item()) def next_train(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next training batch @@ -408,7 +415,7 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]: Returns a Camera instead of raybundle""" self.eval_count += 1 if self.config.cache_images == "disk": - camera, data = next(self.iter_train_image_dataloader)[0] + camera, data = next(self.iter_eval_image_dataloader)[0] return camera, data return self.next_eval_image(step=step) diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py index 10faf3bd56..c1506c47ab 100644 --- a/nerfstudio/data/utils/dataloaders.py +++ b/nerfstudio/data/utils/dataloaders.py @@ -29,7 +29,7 @@ import numpy as np import torch from rich.progress import track -from torch.utils.data import Dataset, get_worker_info +from torch.utils.data import Dataset, IterableDataset, get_worker_info from torch.utils.data.dataloader import DataLoader from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper @@ -237,7 +237,6 @@ def _undistort_image( map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, - borderValue=0, ) / 255.0 ).bool()[..., None] @@ -265,7 +264,7 @@ def undistort_view( f'does not match the camera parameters ({camera.width.item(), camera.height.item()}), idx = {idx}' ) if camera.distortion_params is None or torch.all(camera.distortion_params == 0): - return data + return camera, data K = camera.get_intrinsics_matrices().numpy() distortion_params = camera.distortion_params.numpy() image = data["image"].numpy() @@ -397,7 +396,7 @@ def __iter__(self): yield collated_batch -class RayBatchStream(torch.utils.data.IterableDataset): +class RayBatchStream(IterableDataset): """Wrapper around Pytorch's IterableDataset to generate the next batch of rays (next RayBundle) and corresponding labels with multiple parallel workers. @@ -451,11 +450,11 @@ def __init__( """Each worker has a self._cached_collated_batch contains a collated batch of images cached in RAM for a specific worker that's ready for pixel sampling.""" self.pixel_sampler_config: PixelSamplerConfig = PixelSamplerConfig() """Specifies the pixel sampler config used to sample pixels from images. Each worker will have its own pixel sampler""" - self.ray_generator: RayGenerator = None + self.ray_generator: Optional[RayGenerator] = None """Each worker will have its own ray generator, so this is set to None for now.""" self.custom_ray_processor = custom_ray_processor - def _get_pixel_sampler(self, dataset: Dataset, num_rays_per_batch: int) -> PixelSampler: + def _get_pixel_sampler(self, dataset: InputDataset, num_rays_per_batch: int) -> PixelSampler: """copied from VanillaDataManager.""" from nerfstudio.cameras.cameras import CameraType @@ -579,7 +578,7 @@ def __iter__(self): yield ray_bundle, batch -class ImageBatchStream(torch.utils.data.IterableDataset): +class ImageBatchStream(IterableDataset): """ A wrapper of InputDataset that outputs undistorted full images and cameras. This makes the datamanager more lightweight since we don't have to do generate rays. Useful for full-image @@ -622,7 +621,7 @@ def __iter__(self): r.shuffle(worker_indices) i = 0 idx = worker_indices[i] # idx refers to the actual datapoint index this worker will retrieve - camera, data = undistort_view(idx, self.input_dataset, self.cache_images_type) + camera, data = undistort_view(idx, self.input_dataset, self.cache_images_type) # type: ignore if camera.metadata is None: camera.metadata = {} camera.metadata["cam_idx"] = idx From e9c2fd6f773d19b6040231423e6fc751c4cb6eab Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Mon, 28 Oct 2024 04:59:02 -0700 Subject: [PATCH 77/78] another pyright fixing --- nerfstudio/configs/method_configs.py | 2 -- nerfstudio/data/datamanagers/base_datamanager.py | 1 - 2 files changed, 3 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index e641eb458b..daca1425de 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -29,7 +29,6 @@ from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager -from nerfstudio.data.datamanagers.parallel_full_images_datamanager import ParallelFullImageDatamanager from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig @@ -666,7 +665,6 @@ mixed_precision=False, pipeline=VanillaPipelineConfig( datamanager=FullImageDatamanagerConfig( - _target=ParallelFullImageDatamanager[InputDataset], dataparser=NerfstudioDataParserConfig(load_3D_points=True, downscale_factor=1), # dataparser=NerfstudioDataParserConfig(load_3D_points=True), cache_images_type="uint8", diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py index 0df4a97c44..1fefaedc24 100644 --- a/nerfstudio/data/datamanagers/base_datamanager.py +++ b/nerfstudio/data/datamanagers/base_datamanager.py @@ -162,7 +162,6 @@ class DataManager: train_sampler: Optional[DistributedSampler] = None eval_sampler: Optional[DistributedSampler] = None includes_time: bool = False - test_mode: Literal["test", "val", "inference"] = "val" def __init__(self): """Constructor for the DataManager class. From e4dc9f9015ae4128dd7418cd7019ab3927b71d97 Mon Sep 17 00:00:00 2001 From: AntonioMacaronio Date: Fri, 1 Nov 2024 02:50:31 -0700 Subject: [PATCH 78/78] fixing pyright error, camera optimization no longer part of datamanager --- nerfstudio/viewer/viewer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index bc58043aa6..8db05c0dda 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -372,9 +372,7 @@ def update_camera_poses(self): # TODO this fn accounts for like ~5% of total train time # Update the train camera locations based on optimization assert self.camera_handles is not None - if hasattr(self.pipeline.datamanager, "train_camera_optimizer"): - camera_optimizer = self.pipeline.datamanager.train_camera_optimizer - elif hasattr(self.pipeline.model, "camera_optimizer"): + if hasattr(self.pipeline.model, "camera_optimizer"): camera_optimizer = self.pipeline.model.camera_optimizer else: return