diff --git a/nerfacc/grid.py b/nerfacc/grid.py index 29f3be78..9c3c2eac 100644 --- a/nerfacc/grid.py +++ b/nerfacc/grid.py @@ -169,9 +169,9 @@ def __init__( grid_coords = _meshgrid3d(resolution).reshape( self.num_cells, self.NUM_DIM ) - self.register_buffer("grid_coords", grid_coords) + self.register_buffer("grid_coords", grid_coords, persistent=False) grid_indices = torch.arange(self.num_cells) - self.register_buffer("grid_indices", grid_indices) + self.register_buffer("grid_indices", grid_indices, persistent=False) @torch.no_grad() def _get_all_cells(self) -> torch.Tensor: