diff --git a/nerfacc/vol_rendering.py b/nerfacc/vol_rendering.py index 193ebd89..2f464fc3 100644 --- a/nerfacc/vol_rendering.py +++ b/nerfacc/vol_rendering.py @@ -190,11 +190,10 @@ def accumulate_along_rays( n_rays = int(ray_indices.max()) + 1 # assert n_rays > ray_indices.max() - index = ray_indices[:, None].expand(-1, src.shape[-1]) outputs = torch.zeros( (n_rays, src.shape[-1]), device=src.device, dtype=src.dtype ) - outputs.scatter_add_(0, index, src) + outputs.index_add_(0, ray_indices, src) return outputs